diff --git a/CMakeLists.txt b/CMakeLists.txt index c6ddb9b3d4..0c00c04641 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -210,9 +210,9 @@ set(HEADER_SOURCES src/include/param.h src/include/channel.h src/include/nvtx_stub.h - src/include/nvtx3.hpp src/include/core.h src/include/info.h + src/include/ipcsocket.h src/include/git_version.h src/include/npkit/npkit_event.h src/include/npkit/npkit.h @@ -301,6 +301,7 @@ set(CC_SOURCES src/misc/socket.cc src/misc/param.cc src/misc/rocmwrap.cc + src/misc/ipcsocket.cc src/misc/strongstream.cc src/misc/msccl/msccl_lifecycle.cc src/misc/msccl/msccl_parser.cc @@ -463,11 +464,12 @@ foreach(target ${AMDGPU_TARGETS}) target_link_libraries(rccl PRIVATE --amdgpu-target=${target}) endforeach() +set(ENABLE_IFC 1 CACHE BOOL "Enable indirect function call") if("${HIP_COMPILER}" MATCHES "clang") find_program( hipcc_executable hipcc ) execute_process(COMMAND bash "-c" "${hipcc_executable} --version | grep 'HIP version' | awk -F\" \" '{ printf $3}' | awk -F\"-\" '{ printf $1}'" OUTPUT_VARIABLE hipcc_version_string) message(STATUS "hipcc version: ${hipcc_version_string}") - if(${hipcc_version_string} VERSION_GREATER_EQUAL "5.5.30201") + if(${hipcc_version_string} VERSION_GREATER_EQUAL "5.5.30201" AND ENABLE_IFC) add_definitions(-DUSE_INDIRECT_FUNCTION_CALL) target_compile_options(rccl PRIVATE -fvisibility=hidden) message(STATUS "Indirect function call enabled") diff --git a/makefiles/version.mk b/makefiles/version.mk index e8e7b7a952..6877b63a09 100644 --- a/makefiles/version.mk +++ b/makefiles/version.mk @@ -1,6 +1,6 @@ ##### version NCCL_MAJOR := 2 -NCCL_MINOR := 16 -NCCL_PATCH := 5 +NCCL_MINOR := 17 +NCCL_PATCH := 1 NCCL_SUFFIX := PKG_REVISION := 1 diff --git a/src/Makefile b/src/Makefile index 4753018c1b..ca5ddce466 100644 --- a/src/Makefile +++ b/src/Makefile @@ -12,7 +12,8 @@ INCEXPORTS := nccl.h nccl_net.h LIBSRCFILES := init.cc init_nvtx.cc channel.cc bootstrap.cc transport.cc enqueue.cc group.cc debug.cc proxy.cc net.cc \ misc/cudawrap.cc misc/nvmlwrap.cc misc/ibvwrap.cc misc/gdrwrap.cc \ misc/utils.cc misc/argcheck.cc misc/socket.cc misc/shmutils.cc misc/profiler.cc misc/param.cc misc/strongstream.cc \ - transport/p2p.cc transport/shm.cc transport/net.cc transport/net_socket.cc transport/net_ib.cc transport/coll_net.cc \ + misc/ipcsocket.cc \ + transport/p2p.cc transport/shm.cc transport/net.cc transport/net_socket.cc transport/net_ib.cc transport/coll_net.cc transport/nvls.cc \ collectives/sendrecv.cc collectives/all_reduce.cc collectives/all_gather.cc collectives/broadcast.cc collectives/reduce.cc collectives/reduce_scatter.cc \ graph/topo.cc graph/paths.cc graph/search.cc graph/connect.cc graph/rings.cc graph/trees.cc graph/tuning.cc graph/xml.cc @@ -62,7 +63,7 @@ ALWAYS_REBUILD: -include $(DEPFILES) $(LIBDIR)/$(LIBTARGET) $(LIBDIR)/$(STATICLIBTARGET) : $(LIBOBJ) -$(INCDIR)/nccl.h : nccl.h.in +$(INCDIR)/nccl.h : nccl.h.in ../makefiles/version.mk # NCCL_VERSION(X,Y,Z) ((X) * 10000 + (Y) * 100 + (Z)) @$(eval NCCL_VERSION := $(shell printf "%d%02d%02d" $(NCCL_MAJOR) $(NCCL_MINOR) $(NCCL_PATCH))) mkdir -p $(INCDIR) diff --git a/src/bootstrap.cc b/src/bootstrap.cc index 2a96e94a68..e542e26c87 100644 --- a/src/bootstrap.cc +++ b/src/bootstrap.cc @@ -394,6 +394,24 @@ ncclResult_t bootstrapIntraNodeAllGather(void* commState, int *ranks, int rank, return ncclSuccess; } +// IntraNode in-place Broadcast +ncclResult_t bootstrapIntraNodeBroadcast(void* commState, int *ranks, int rank, int nranks, int root, void* bcastData, int size) { + if (nranks == 1) return ncclSuccess; + TRACE(NCCL_INIT, "rank %d nranks %d root %d size %d - ENTER", rank, nranks, root, size); + + if (rank == root) { + for (int i=0; iid != -1) return ncclSuccess; int nRanks = comm->nRanks; + int nPeers = nRanks + 1 /* Collnet */ + comm->localRanks /* NVLS */; channel->id = channelId; channel->workFifoSent = 0; NCCLCHECK(ncclStrongStreamAcquireUncaptured(&comm->deviceStream)); // The extra on nRanks+1 is for collnet root (i.e. network) - channel->peers = ncclMemoryStackAlloc(&comm->memPermanent, nRanks+1); - NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nRanks+1, comm->deviceStream.cudaStream)); + channel->peers = ncclMemoryStackAlloc(&comm->memPermanent, nPeers); + NCCLCHECK(ncclCudaCallocAsync(&channel->devPeers, nPeers, comm->deviceStream.cudaStream)); ncclCommPushCudaFree(comm, channel->devPeers); channel->ring.userRanks = ncclMemoryStackAlloc(&comm->memPermanent, nRanks); @@ -29,7 +30,7 @@ ncclResult_t initChannel(struct ncclComm* comm, int channelId) { NCCLCHECK(ncclStrongStreamRelease(ncclCudaGraphNone(), &comm->deviceStream)); - for (int r=0; r < nRanks+1; ++r) { + for (int r=0; r < nPeers; ++r) { for (int b=0; b < NCCL_MAX_CONNS; b++) { channel->peers[r].send[b].comm = comm; channel->peers[r].recv[b].comm = comm; diff --git a/src/collectives/device/all_gather.h b/src/collectives/device/all_gather.h index 8837ad836a..4bc9bc2868 100644 --- a/src/collectives/device/all_gather.h +++ b/src/collectives/device/all_gather.h @@ -112,3 +112,45 @@ struct RunWorkElement(args); } }; + +template +struct RunWorkElement { + __device__ __forceinline__ void run(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nChannels = args->nChannels; + struct ncclNvls* nvls = &ncclShmem.channel.nvls; + const ssize_t chunkSize = int(args->lastChunkSize); + const ssize_t size = args->count; + const ssize_t loopSize = nChannels*chunkSize; + + const int nThreadsGather = 128; + const int nThreadsBcast = 384 + WARP_SIZE; + const int tidEndGather = nThreadsGather; + const int tidEndBcast = tidEndGather + nThreadsBcast; + + using Proto = ProtoSimple<1, 1>; + + if (tid < tidEndGather) { + // Gather + int group = (0*Proto::MaxGroupWidth) | (0<<16); + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, nThreadsGather, nvls->up, NULL, NULL, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.gather(offset, nvls->nHeads*size, nelem, size, -1, 0); + } + } else if (tid < tidEndBcast) { + int group = (3*Proto::MaxGroupWidth) | (1<<16); + // Bcast through MC + Primitives, /*Direct=*/0, Proto, 0> + prims(tid-tidEndGather, nThreadsBcast, NULL, &nvls->down, args->sendbuff, NULL, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.send(offset, nelem); + } + } + } +}; diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index aa8aa55224..02634ec772 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -358,7 +358,7 @@ namespace { const int nthreads = args->nWarps*WARP_SIZE; const int bid = args->bid; const int nChannels = args->nChannels; - ncclTree *tree = (args->pad_0 == 2) ? &ncclShmem.channel.binTree : &ncclShmem.channel.tree; + ncclTree *tree = &ncclShmem.channel.tree; ssize_t chunkSize = int( Proto::Id != NCCL_PROTO_LL ? args->lastChunkSize : Proto::calcBytePerStep()/sizeof(T)); @@ -583,9 +583,9 @@ struct RunWorkElementnHeads*chunkSize; int nelem = min(direct->nHeads*chunkSize, size-offset); if (args->regUsed) { - prims.directScatter(offset, nelem, chunkSize, direct->headRank, direct->shift); + prims.directScatter(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); } else { - prims.scatter(offset, nelem, chunkSize, direct->headRank, direct->shift); + prims.scatter(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); } } } else if (tid >= tidStartReduce && direct->out != -1) { @@ -621,7 +621,7 @@ struct RunWorkElementnHeads*chunkSize; int nelem = min(direct->nHeads*chunkSize, size-offset); - prims.directGather(offset, nelem, chunkSize, direct->headRank, direct->shift); + prims.directGather(offset, nelem, chunkSize, chunkSize, direct->headRank, direct->shift); } } else if (tid >= tidStartBcast && tid < tidStartScatter && direct->out != -1) { int group = (1*Proto::MaxGroupWidth) | (0<<16); @@ -648,6 +648,65 @@ struct RunWorkElement +struct RunWorkElement { + __device__ __forceinline__ void run(ncclWorkElem *args) { + #if NCCL_NVLS_ENABLED + const int tid = threadIdx.x; + const int bid = args->bid; + const int nChannels = args->nChannels; + struct ncclNvls* nvls = &ncclShmem.channel.nvls; + const ssize_t chunkSize = int(args->lastChunkSize); + const ssize_t size = args->count; + const ssize_t loopSize = nChannels*nvls->nHeads*chunkSize; + const int nranks = ncclShmem.comm.nRanks; + const int reduceWarps = nranks <= 6 ? 6 : 4; + const int copyWarps = ((NCCL_MAX_NTHREADS/WARP_SIZE) - reduceWarps)/2; + + const int nThreadsScatter = copyWarps*WARP_SIZE; + const int nThreadsGather = (copyWarps-1)*WARP_SIZE; + const int nThreadsReduce = (reduceWarps+1)*WARP_SIZE; + const int tidEndScatter = nThreadsScatter; + const int tidEndGather = tidEndScatter + nThreadsGather; + const int tidEndReduce = tidEndGather + nThreadsReduce; + + using Proto = ProtoSimple<1, 1, COLL_UNROLL, /*NVLS=*/true>; + + if (tid < tidEndScatter) { + // Scatter + int group = (0*Proto::MaxGroupWidth) | (0<<16); + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize; + int nelem = min(nvls->nHeads*chunkSize, size-offset); + prims.scatter(offset, nelem, chunkSize, chunkSize, -1, 0); + } + } else if (tid < tidEndGather) { + // Gather + int group = (2*Proto::MaxGroupWidth) | (0<<16); + Primitives, /*Direct=*/0, Proto, 0> + prims(tid-tidEndScatter, nThreadsGather, nvls->up, NULL, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*nvls->nHeads*chunkSize; + int nelem = min(nvls->nHeads*chunkSize, size-offset); + prims.gather(offset, nelem, chunkSize, chunkSize, -1, 0); + } + } else if (tid < tidEndReduce) { + int group = (3*Proto::MaxGroupWidth) | (1<<16); + // Reduce, broadcast through NVLS + Primitives, /*Direct=*/0, Proto, 0> + prims(tid-tidEndGather, nThreadsReduce, &nvls->down, &nvls->down, args->sendbuff, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + (bid*nvls->nHeads+nvls->headRank)*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.recvSend(nelem); + } + } + #endif // NCCL_NVLS_ENABLED + } +}; + template struct RunWorkElement { __device__ __forceinline__ void run(ncclWorkElem *args) { diff --git a/src/collectives/device/alltoall_pivot.h b/src/collectives/device/alltoall_pivot.h index 93dba6d888..443a369cee 100644 --- a/src/collectives/device/alltoall_pivot.h +++ b/src/collectives/device/alltoall_pivot.h @@ -52,7 +52,7 @@ namespace { const T* sendbuff = (const T*)args->sendbuff + send_offset; T* recvbuff = (T *)args->recvbuff + recv_offset; ReduceOrCopyMulti( - tid, nthreads, nullptr, false, 1, &sendbuff, 1, &recvbuff, send_recv_size); + tid, nthreads, 0, nullptr, false, 1, (void **)&sendbuff, 1, (void **)&recvbuff, send_recv_size); } else { for (ssize_t prims_offset = 0; prims_offset < send_recv_size; prims_offset += prims_size) { const int prims_nelem = min(prims_size, send_recv_size - prims_offset); diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index d82aead3bf..0d9cd3b734 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -29,14 +29,15 @@ #define NCCL_FUNC5(func, algo, devredop, type, nullify) \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) #define NCCL_FUNC4(func, devredop, type, nullify) \ NCCL_FUNC5(func, TREE, devredop, type, nullify), \ NCCL_FUNC5(func, RING, devredop, type, nullify), \ NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \ - NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify) + NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \ + NCCL_FUNC5(func, NVLS, devredop, type, nullify) // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(func, devredop, nullForFloat) \ @@ -113,94 +114,8 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{ #endif }; -#define NCCL_FUNC5_LL128(func, algo, devredop, type, nullify) \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ - MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) - -#define NCCL_FUNC4_LL128(func, devredop, type, nullify) \ - NCCL_FUNC5_LL128(func, TREE, devredop, type, nullify), \ - NCCL_FUNC5_LL128(func, RING, devredop, type, nullify), \ - NCCL_FUNC5_LL128(func, COLLNET_DIRECT, devredop, type, nullify), \ - NCCL_FUNC5_LL128(func, COLLNET_CHAIN, devredop, type, nullify) - -// Must be consistent with ncclDataType_t -#define NCCL_FUNCS3A_LL128(func, devredop, nullForFloat) \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, uint8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int32_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, uint32_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int64_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, uint64_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, half, nullForFloat), \ - NCCL_FUNC4_LL128(func, devredop, float, nullForFloat), \ - NCCL_FUNC4_LL128(func, devredop, double, nullForFloat), \ - NCCL_FUNC4_LL128(func, devredop, rccl_bfloat16, nullForFloat) -#define NCCL_FUNCS3B_LL128(func, devredop) \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ - NCCL_FUNC4_LL128(func, devredop, int8_t, 0) - -// Must be consistent with ncclRedOp_t -#define NCCL_FUNCS2A_LL128(func) \ - NCCL_FUNCS3A_LL128(func, Sum, /*nullForFloat=*/0), \ - NCCL_FUNCS3A_LL128(func, Prod, /*nullForFloat=*/0), \ - NCCL_FUNCS3A_LL128(func, Max, /*nullForFloat=*/0), \ - NCCL_FUNCS3A_LL128(func, Min, /*nullForFloat=*/0), \ - NCCL_FUNCS3A_LL128(func, PreMulSum, /*nullForFloat=*/0), \ - NCCL_FUNCS3A_LL128(func, SumPostDiv, /*nullForFloat=*/1) - -#define NCCL_FUNCS2B_LL128(func) \ - NCCL_FUNCS3B_LL128(func, Sum), \ - NCCL_FUNCS3B_LL128(func, Sum), \ - NCCL_FUNCS3B_LL128(func, Sum), \ - NCCL_FUNCS3B_LL128(func, Sum), \ - NCCL_FUNCS3B_LL128(func, Sum), \ - NCCL_FUNCS3B_LL128(func, Sum) - -// Must be consistent with the ncclFuncSet enum -using ncclKernelFunc_t = void (*)(); - -static const __device__ constexpr ncclKernelFunc_t ncclFuncs_ll128[]{ -// Don't try to initialize the host shadow copy of this device-side global -// variable. There is no host pointer to a device-side function, which -// confuses clang. This will be fixed in the next clang release. -#if defined(__HIP_DEVICE_COMPILE__) -#if defined(BUILD_ALLREDUCE_ONLY) - NCCL_FUNC4_LL128(AllReduce, Sum, float, 0), -#else - NCCL_FUNCS2B_LL128(Broadcast), - NCCL_FUNCS2A_LL128(Reduce), - NCCL_FUNCS2B_LL128(AllGather), - NCCL_FUNCS2A_LL128(ReduceScatter), - NCCL_FUNCS2A_LL128(AllReduce), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, half), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, float), - NCCL_ONERANK_REDUCE_NAME(PreMulSum, double), -#if defined(RCCL_BFLOAT16) - NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16), -#endif - NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), - NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t), -#endif -#endif -}; - -static_assert(FUNC_INDEX_P2P == 3610, "Wrong P2P function index"); -static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 3611, "Wrong AllToAllPivot function index"); +static_assert(FUNC_INDEX_P2P == 4510, "Wrong P2P function index"); +static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 4511, "Wrong AllToAllPivot function index"); #ifndef USE_INDIRECT_FUNCTION_CALL template @@ -217,7 +132,7 @@ struct Caller { template struct Caller{ static __forceinline__ __device__ __host__ - void call(unsigned short funcIndex) noexcept { if (u) ncclFuncs_ll128[f](); else ncclFuncs[f](); } + void call(unsigned short funcIndex) noexcept { ncclFuncs[f](); } }; template @@ -260,46 +175,46 @@ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept { else assert("Unsupported function index"); #else - if (funcIndex < 720) { - if (funcIndex % 12 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 12 == 1) ncclFunction_Broadcast_TREE_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 12 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); - else if (funcIndex % 12 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t(); - else if (funcIndex % 12 == 3) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 12 == 4) ncclFunction_Broadcast_RING_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 12 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); - else if (funcIndex % 12 == 5) ncclFunction_Broadcast_RING_SIMPLE_Sum_int8_t(); - else if (funcIndex % 12 == 6) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 12 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 12 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t(); - else if (funcIndex % 12 == 8) ncclFunction_Broadcast_COLLNET_DIRECT_SIMPLE_Sum_int8_t(); - else if (funcIndex % 12 == 9) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 12 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 12 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t(); + if (funcIndex < 900) { + if (funcIndex % 15 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 15 == 1) ncclFunction_Broadcast_TREE_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 15 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); + else if (funcIndex % 15 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t(); + else if (funcIndex % 15 == 3) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 15 == 4) ncclFunction_Broadcast_RING_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 15 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); + else if (funcIndex % 15 == 5) ncclFunction_Broadcast_RING_SIMPLE_Sum_int8_t(); + else if (funcIndex % 15 == 6) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 15 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 15 == 7) ncclFunction_Broadcast_COLLNET_DIRECT_LL_Sum_int8_t(); + else if (funcIndex % 15 == 8) ncclFunction_Broadcast_COLLNET_DIRECT_SIMPLE_Sum_int8_t(); + else if (funcIndex % 15 == 9) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 15 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 15 == 10) ncclFunction_Broadcast_COLLNET_CHAIN_LL_Sum_int8_t(); else ncclFunction_Broadcast_COLLNET_CHAIN_SIMPLE_Sum_int8_t(); } - else if (funcIndex < 1440) Caller<720, 1440, USING_LL128>::call(funcIndex); - else if (funcIndex < 2160) { - if (funcIndex % 12 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 12 == 1) ncclFunction_AllGather_TREE_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 12 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); - else if (funcIndex % 12 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t(); - else if (funcIndex % 12 == 3) ncclFunction_AllGather_RING_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 12 == 4) ncclFunction_AllGather_RING_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 12 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t(); - else if (funcIndex % 12 == 5) ncclFunction_AllGather_RING_SIMPLE_Sum_int8_t(); - else if (funcIndex % 12 == 6) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 12 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 12 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t(); - else if (funcIndex % 12 == 8) ncclFunction_AllGather_COLLNET_DIRECT_SIMPLE_Sum_int8_t(); - else if (funcIndex % 12 == 9) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t(); - else if (USING_LL128 && funcIndex % 12 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL128_Sum_int8_t(); - else if (!USING_LL128 && funcIndex % 12 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t(); + else if (funcIndex < 1800) Caller<900, 1800, USING_LL128>::call(funcIndex); + else if (funcIndex < 2700) { + if (funcIndex % 15 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 15 == 1) ncclFunction_AllGather_TREE_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 15 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); + else if (funcIndex % 15 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t(); + else if (funcIndex % 15 == 3) ncclFunction_AllGather_RING_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 15 == 4) ncclFunction_AllGather_RING_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 15 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t(); + else if (funcIndex % 15 == 5) ncclFunction_AllGather_RING_SIMPLE_Sum_int8_t(); + else if (funcIndex % 15 == 6) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 15 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 15 == 7) ncclFunction_AllGather_COLLNET_DIRECT_LL_Sum_int8_t(); + else if (funcIndex % 15 == 8) ncclFunction_AllGather_COLLNET_DIRECT_SIMPLE_Sum_int8_t(); + else if (funcIndex % 15 == 9) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 15 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 15 == 10) ncclFunction_AllGather_COLLNET_CHAIN_LL_Sum_int8_t(); else ncclFunction_AllGather_COLLNET_CHAIN_SIMPLE_Sum_int8_t(); } - else if (funcIndex < 3600) Caller<2160, 3600, USING_LL128>::call(funcIndex); + else if (funcIndex < 4500) Caller<2700, 4500, USING_LL128>::call(funcIndex); else { - switch (funcIndex - 3600) { + switch (funcIndex - 4500) { case 0: ncclFunction_OneRankReduce_PreMulSum_int8_t(); break; @@ -479,22 +394,19 @@ class ncclFunction { #define traceData(data2, data4, data8_0, data8_1) #endif - struct ncclShmemGroup { - ncclConnInfo *recvConns[NCCL_MAX_DIRECT_ARITY]; - ncclConnInfo *sendConns[NCCL_MAX_DIRECT_ARITY]; - void* srcs[NCCL_MAX_DIRECT_ARITY+1]; - void* dsts[NCCL_MAX_DIRECT_ARITY+1]; - int totalSendSize[NCCL_MAX_SLICE_PER_CHUNK]; + ncclConnInfo *recvConns[NCCL_MAX_NVLS_ARITY]; + ncclConnInfo *sendConns[NCCL_MAX_NVLS_ARITY]; + void* srcs[NCCL_MAX_NVLS_ARITY+1]; + void* dsts[NCCL_MAX_NVLS_ARITY+1]; + int nvlsRecv; uint64_t barrier; uint64_t barrier_next[NCCL_MAX_GROUPS]; }; struct ncclShmemData { - union { - struct ncclShmemGroup groups[NCCL_MAX_GROUPS]; - }; - uint64_t redOpArgs[NCCL_MAX_DIRECT_ARITY+1]; + struct ncclShmemGroup groups[NCCL_MAX_GROUPS]; + uint64_t redOpArgs[NCCL_MAX_NVLS_ARITY+1]; int channelId; int aborted; alignas(16) struct ncclDevComm comm; @@ -507,6 +419,15 @@ struct ncclShmemData { static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "ncclShmem.work needs to be 16B aligned"); extern __shared__ ncclShmemData ncclShmem; +#if __CUDA_ARCH__ >= 700 + extern __shared__ ulong2 ncclShmemPerWarp[/*ncclShmemDynamicSize()/sizeof(ulong2)*/]; +#else + extern __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)]; +#endif + +__device__ inline void* ncclScratchForWarp(int warp) { + return (char*)ncclShmemPerWarp + warp*ncclShmemScratchWarpSize(); +} #ifdef ENABLE_PROFILING #define __insert_timestamp(line_num) do { \ @@ -578,7 +499,7 @@ static __forceinline__ __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we } } -template +template __forceinline__ __device__ void ncclKernel( struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead ) { @@ -670,10 +591,9 @@ __forceinline__ __device__ void ncclKernel( RunWork().run(&ncclShmem.work); } else { #ifdef USE_INDIRECT_FUNCTION_CALL - if (USING_LL128) ncclFuncs_ll128[ncclShmem.work.header.funcIndex](); - else ncclFuncs[ncclShmem.work.header.funcIndex](); + ncclFuncs[ncclShmem.work.header.funcIndex](); #else - NCCL_CALL_FUNCTIONS(ncclShmem.work.header.funcIndex); + NCCL_CALL_FUNCTIONS<1>(ncclShmem.work.header.funcIndex); #endif } @@ -705,22 +625,12 @@ __forceinline__ __device__ void ncclKernel( #define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \ __launch_bounds__(NCCL_MAX_NTHREADS, 1) \ __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false, false>(comm, channelMask, workHead); \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, channelMask, workHead); \ } \ \ __launch_bounds__(NCCL_MAX_NTHREADS, 1) \ __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true, false>(comm, channelMask, workHead); \ -} \ - \ -__launch_bounds__(NCCL_MAX_NTHREADS, 1) \ -__global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false, true>(comm, channelMask, workHead); \ -} \ - \ -__launch_bounds__(NCCL_MAX_NTHREADS, 1) \ -__global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) { \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true, true>(comm, channelMask, workHead); \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(comm, channelMask, workHead); \ } // Examples : AllReduce, RING, LL, Sum, uint8 @@ -748,7 +658,8 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev IMPL_COLL4(func, TREE, devredop, type, ncclType) \ IMPL_COLL4(func, RING, devredop, type, ncclType) \ IMPL_COLL4(func, COLLNET_DIRECT, devredop, type, ncclType) \ - IMPL_COLL4(func, COLLNET_CHAIN, devredop, type, ncclType) + IMPL_COLL4(func, COLLNET_CHAIN, devredop, type, ncclType) \ + IMPL_COLL4(func, NVLS, devredop, type, ncclType) #define IMPL_COLL2(func, devredop) \ IMPL_COLL3(func, devredop, int8_t, ncclInt8) \ @@ -791,4 +702,6 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev #define IMPL_COLL_F(func) \ IMPL_COLL_FUNC(func, RING, SIMPLE, Sum, int8_t); +#define NCCL_NVLS_ENABLED (__CUDA_ARCH__ >= 900 && NCCL_NVLS_SUPPORTS(NCCL_TYPE, NCCL_OP)) + #endif diff --git a/src/collectives/device/common_kernel.h b/src/collectives/device/common_kernel.h index b8b370a10c..cd1aa5b6c8 100644 --- a/src/collectives/device/common_kernel.h +++ b/src/collectives/device/common_kernel.h @@ -9,712 +9,373 @@ #define NCCL_COMMON_KERNEL_H_ #include "devcomm.h" +#include "op128.h" +#include "reduce_kernel.h" #include #include #include +#define __syncwarp() + // Define min for ssize_t -static __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; } +inline __device__ int min(int a, ssize_t b) { return (a < b) ? a : b; } inline __device__ int loadInt(int* ptr) { int v; -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) v = atomicAdd((unsigned long long *)ptr, 0); -#else - asm volatile("ld.volatile.global.u32 %0, [%1];" - : "=r"(v) : "l"(ptr)); -#endif return v; } -typedef uint64_t PackType; - -template -struct FuncTraits /*{ - __device__ static T preOp(Fn, T); - __device__ static T postOp(Fn, T); -}*/; - -// unpack x and y to elements of type T and apply FUNC to each element -template -struct MULTI { - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const; - __device__ PackType preOp(FUNC fn, PackType x) const; - __device__ PackType postOp(FUNC fn, PackType x) const; -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of uint32_t."); - union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - // for char, we do these as vector ops - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - int8_t elt[8]; - } u; - u.pack = x; - #pragma unroll - for (int i=0; i < 8; i++) - u.elt[i] = FuncTraits().preOp(fn, u.elt[i]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - int8_t elt[8]; - } u; - u.pack = x; - #pragma unroll - for (int i=0; i < 8; i++) - u.elt[i] = FuncTraits().postOp(fn, u.elt[i]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of uint32_t."); - union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - // for char, we do these as vector ops - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint8_t elt[8]; - } u; - u.pack = x; - #pragma unroll - for (int i=0; i < 8; i++) - u.elt[i] = FuncTraits().preOp(fn, u.elt[i]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint8_t elt[8]; - } u; - u.pack = x; - #pragma unroll - for (int i=0; i < 8; i++) - u.elt[i] = FuncTraits().postOp(fn, u.elt[i]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(int32_t), - "PackType must be twice the size of int."); - union converter { - PackType storage; - struct { - int32_t a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - int32_t elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - int32_t elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(uint32_t), - "PackType must be twice the size of int."); - union converter { - PackType storage; - struct { - uint32_t a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint32_t elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint32_t elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == 4 * sizeof(half), - "PackType must be four times the size of half."); - - union Converter { - PackType pack; - half2 h2[2]; - }; - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - Converter cx, cy, cr; - cx.pack = x; - cy.pack = y; - cr.h2[0] = fn(cx.h2[0], cy.h2[0]); - cr.h2[1] = fn(cx.h2[1], cy.h2[1]); - return cr.pack; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - Converter c; - c.pack = x; - c.h2[0] = FuncTraits().preOp(fn, c.h2[0]); - c.h2[1] = FuncTraits().preOp(fn, c.h2[1]); - return c.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - Converter c; - c.pack = x; - c.h2[0] = FuncTraits().postOp(fn, c.h2[0]); - c.h2[1] = FuncTraits().postOp(fn, c.h2[1]); - return c.pack; - } -}; - -#if defined(RCCL_BFLOAT16) -template -struct MULTI { - static_assert(sizeof(PackType) == 4 * sizeof(rccl_bfloat16), - "PackType must be four times the size of rccl_bfloat16."); - - union Converter { - PackType pack; - rccl_bfloat16 h2[4]; - }; - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - Converter cx, cy, cr; - cx.pack = x; - cy.pack = y; - cr.h2[0] = fn(cx.h2[0], cy.h2[0]); - cr.h2[1] = fn(cx.h2[1], cy.h2[1]); - cr.h2[2] = fn(cx.h2[2], cy.h2[2]); - cr.h2[3] = fn(cx.h2[3], cy.h2[3]); - return cr.pack; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - Converter c; - c.pack = x; - c.h2[0] = FuncTraits().preOp(fn, c.h2[0]); - c.h2[1] = FuncTraits().preOp(fn, c.h2[1]); - c.h2[2] = FuncTraits().preOp(fn, c.h2[2]); - c.h2[3] = FuncTraits().preOp(fn, c.h2[3]); - return c.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - Converter c; - c.pack = x; - c.h2[0] = FuncTraits().postOp(fn, c.h2[0]); - c.h2[1] = FuncTraits().postOp(fn, c.h2[1]); - c.h2[2] = FuncTraits().postOp(fn, c.h2[2]); - c.h2[3] = FuncTraits().postOp(fn, c.h2[3]); - return c.pack; - } -}; -#endif - -template -struct MULTI { - static_assert(sizeof(PackType) == 2 * sizeof(float), - "PackType must be twice the size of float."); - union converter { - PackType storage; - struct { - float a, b; - }; - }; - - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - - cr.a = fn(cx.a, cy.a); - cr.b = fn(cx.b, cy.b); - - return cr.storage; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - float elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().preOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().preOp(fn, u.elt[1]); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - float elt[2]; - } u; - u.pack = x; - u.elt[0] = FuncTraits().postOp(fn, u.elt[0]); - u.elt[1] = FuncTraits().postOp(fn, u.elt[1]); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(double), - "PackType must be the same size as double."); - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - double rv = fn(__longlong_as_double(x), __longlong_as_double(y)); - return __double_as_longlong(rv); - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - double elt; - } u; - u.pack = x; - u.elt = FuncTraits().preOp(fn, u.elt); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - double elt; - } u; - u.pack = x; - u.elt = FuncTraits().postOp(fn, u.elt); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(uint64_t), - "PackType must be the same size as uint64_t."); - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - uint64_t rv = fn(x, y); - return rv; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint64_t elt; - } u; - u.pack = x; - u.elt = FuncTraits().preOp(fn, u.elt); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - uint64_t elt; - } u; - u.pack = x; - u.elt = FuncTraits().postOp(fn, u.elt); - return u.pack; - } -}; - -template -struct MULTI { - static_assert(sizeof(PackType) == sizeof(int64_t), - "PackType must be the same size as int64_t."); - __device__ PackType operator()(FUNC fn, const PackType x, const PackType y) const { - int64_t rv = fn((int64_t)x, (int64_t)y); - return rv; - } - __device__ PackType preOp(FUNC fn, PackType x) const { - union { - PackType pack; - int64_t elt; - } u; - u.pack = x; - u.elt = FuncTraits().preOp(fn, u.elt); - return u.pack; - } - __device__ PackType postOp(FUNC fn, PackType x) const { - union { - PackType pack; - int64_t elt; - } u; - u.pack = x; - u.elt = FuncTraits().postOp(fn, u.elt); - return u.pack; - } -}; - -template inline __device__ -T vFetch(const volatile T* ptr) { - return __builtin_nontemporal_load(ptr); -} - -template inline __device__ -void vStore(volatile T* ptr, const T val) { - __builtin_nontemporal_store(val, ptr); -} - -#if CUDART_VERSION < 9000 && !(defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)) -template<> inline __device__ -half vFetch(const volatile half* ptr) { - half r; - r.x = ptr->x; - return r; -} - -template<> inline __device__ -void vStore(volatile half* ptr, const half val) { - ptr->x = val.x; -} -#else -template<> inline __device__ -half vFetch(const volatile half* ptr) { - half h; - uint8_t *pr = (uint8_t *)&h; - pr[0] = __builtin_nontemporal_load((uint8_t*)ptr); - pr[1] = __builtin_nontemporal_load((uint8_t*)ptr+1); - return h; -} - -template<> inline __device__ -void vStore(volatile half* ptr, const half val) { - uint8_t *pr = (uint8_t *)&val; - __builtin_nontemporal_store(pr[0], (uint8_t*)ptr); - __builtin_nontemporal_store(pr[1], ((uint8_t*)ptr)+1); -} - -template<> inline __device__ -rccl_bfloat16 vFetch(const volatile rccl_bfloat16* ptr) { - rccl_bfloat16 r; - uint8_t *pr = (uint8_t *)&r.data; - pr[0] = __builtin_nontemporal_load((uint8_t*)&ptr->data); - pr[1] = __builtin_nontemporal_load(((uint8_t*)&ptr->data)+1); - return r; -} - -template<> inline __device__ -void vStore(volatile rccl_bfloat16* ptr, const rccl_bfloat16 val) { - uint8_t *pr = (uint8_t *)&val.data; - __builtin_nontemporal_store(pr[0], (uint8_t*)&ptr->data); - __builtin_nontemporal_store(pr[1], ((uint8_t*)&ptr->data)+1); -} -#endif - -typedef ulong2 Pack128; - -template -struct MULTI128 { - __device__ void operator()(FUNC fn, Pack128& x, Pack128 const& y) const { - x.x = MULTI()(fn, x.x, y.x); - x.y = MULTI()(fn, x.y, y.y); - } - __device__ void preOp(FUNC fn, Pack128 &x) const { - x.x = MULTI().preOp(fn, x.x); - x.y = MULTI().preOp(fn, x.y); - } - __device__ void postOp(FUNC fn, Pack128 &x) const { - x.x = MULTI().postOp(fn, x.x); - x.y = MULTI().postOp(fn, x.y); - } -}; - -inline __device__ void Fetch128(Pack128& v, const Pack128* p) { -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - v.x = __builtin_nontemporal_load(&p->x); - v.y = __builtin_nontemporal_load(&p->y); -#else - asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory"); -#endif -} -inline __device__ void Store128(Pack128* p, Pack128& v) { -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - __builtin_nontemporal_store(v.x, &p->x); - __builtin_nontemporal_store(v.y, &p->y); -#else - asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory"); -#endif -} - -template -__device__ __forceinline__ void ReduceCopyMulti(const int w, const int nw, const int t, - uint64_t* redOpArgs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const Int Nelem +template +__device__ __forceinline__ void reduceCopyPacks( + int nThreads, int &thread, + uint64_t redArg, uint64_t *preOpArgs, bool postOp, + int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs, + IntBytes &nBytesBehind, IntBytes &nBytesAhead ) { - const Int inc = nw * UNROLL * WARP_SIZE; - Int offset = w * UNROLL * WARP_SIZE + t; + static_assert(std::is_signed::value, "IntBytes must be a signed integral type."); - const T* srcs[MAXSRCS]; - for (int i=0; i().preOp(fn, vals[u]); - } + // This thread's initial position. + IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack); + IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack); + // Number of hunks to be consumed over all warps. + IntBytes nHunksAhead = nBytesAhead/BytePerHunk; + // Advance collective position. + nBytesBehind += nHunksAhead*BytePerHunk; + nBytesAhead -= nHunksAhead*BytePerHunk; + if (Unroll==1 && BytePerPack <= nBytesAhead) { + // Only Unroll=1 can do partial hunks (where not all threads partake). + nHunksAhead += 1; + nBytesBehind += nBytesAhead - (nBytesAhead%BytePerPack); + nBytesAhead = nBytesAhead%BytePerPack; + } + nHunksAhead -= warp; - #pragma unroll - for (int i=1; i().preOp(fn, vals2[u]); + RedFn redFn(redArg); + uintptr_t minSrcs[MinSrcs + !MinSrcs]; + uintptr_t minDsts[MinDsts + !MinDsts]; + #pragma unroll + for (int s=0; s < MinSrcs; s++) + minSrcs[s] = cvta_to_global(srcPtrs[s]) + threadBytesBehind; + #pragma unroll + for (int d=0; d < MinDsts; d++) + minDsts[d] = cvta_to_global(dstPtrs[d]) + threadBytesBehind; + + // We dictate loop termination condition according to whether partial hunks + // can be handled or not. + while (Unroll==1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) { + BytePack acc[Unroll]; + + { RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + acc[u] = ld_volatile_global(minSrcs[0]); + minSrcs[0] += WARP_SIZE*BytePerPack; + if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]); } - for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]); } - #pragma unroll - for (int i=MINSRCS; i().preOp(fn, vals2[u]); - } - for (int u = 0; u < UNROLL; ++u) vals[u] = fn(vals[u], vals2[u]); + + #pragma unroll Unroll + for (int s=1; s < MinSrcs; s++) { + BytePack tmp[Unroll]; + RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + tmp[u] = ld_volatile_global(minSrcs[s]); + minSrcs[s] += WARP_SIZE*BytePerPack; + } + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]); + acc[u] = applyReduce(redFn, acc[u], tmp[u]); + } + } + + for (int s=MinSrcs; (MinSrcs < MaxSrcs) && (s < MaxSrcs) && (s < nSrcs); s++) { + uintptr_t src = cvta_to_global(srcPtrs[s]) + threadBytesBehind; + BytePack tmp[Unroll]; + RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + // Use volatile loads in case credits are polled for with volatile (instead of acquire). + tmp[u] = ld_volatile_global(src); + src += WARP_SIZE*BytePerPack; + } + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]); + acc[u] = applyReduce(redFn, acc[u], tmp[u]); } } if (postOp) { - FUNC fn(redOpArgs[0]); - #pragma unroll - for (int u = 0; u < UNROLL; ++u) vals[u] = FuncTraits().postOp(fn, vals[u]); + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) + acc[u] = applyPostOp(redFn, acc[u]); } - // Store - #pragma unroll - for (int i = 0; i < MINDSTS; i++) { - for (int u = 0; u < UNROLL; ++u) vStore(dsts[i]+u*WARP_SIZE, vals[u]); - } - #pragma unroll - for (int i=MINDSTS; i(minDsts[d], acc[u]); + minDsts[d] += WARP_SIZE*BytePerPack; } } - for (int i=0; i(dst, acc[u]); + dst += WARP_SIZE*BytePerPack; + } + } + + nWarps = nThreads/WARP_SIZE; + #pragma unroll + for (int s=0; s < MinSrcs; s++) minSrcs[s] += (nWarps-1)*BytePerHunk; + #pragma unroll + for (int d=0; d < MinDsts; d++) minDsts[d] += (nWarps-1)*BytePerHunk; + threadBytesBehind += nWarps*BytePerHunk; + threadBytesAhead -= nWarps*BytePerHunk; + nHunksAhead -= nWarps; } + + nWarps = nThreads/WARP_SIZE; + warp = thread/WARP_SIZE; + lane = thread%WARP_SIZE; + // The last loop iteration could have been partial, i.e. not taken by all + // threads. The threads that weren't included need an extra subtraction to + // make the value warp uniform. + if (Unroll==1 && nHunksAhead > 0) nHunksAhead -= nWarps; + // Rotate warps so the warp which got the least work here will be warp 0. + // This effectively assigns: warp = (warp-nHunks+nWarps)%nWarps; + warp = -nHunksAhead; + thread = warp*WARP_SIZE + lane; } -template -__device__ __forceinline__ void ReduceCopy128bMulti(const int w, const int nw, const int t, - uint64_t* redOpArgs, bool postOp, int nsrcs, const T** s, int ndsts, T** d, const int elemOffset, const Int Npack - ) { - const Int inc = nw * UNROLL * WARP_SIZE; - Int offset = w * UNROLL * WARP_SIZE + t; - - const Pack128* srcs[MAXSRCS]; - for (int i=0; i().preOp(fn, vals[u]); - } - - #pragma unroll - for (int i=1; i().preOp(fn, vals2[u]); - } - for (int u = 0; u < UNROLL; ++u) MULTI128()(fn, vals[u], vals2[u]); - } - #pragma unroll - for (int i=MINSRCS; i().preOp(fn, vals2[u]); - } - for (int u = 0; u < UNROLL; ++u) MULTI128()(fn, vals[u], vals2[u]); - } - } - - if (postOp) { - FUNC fn(redOpArgs[0]); - #pragma unroll - for (int u = 0; u < UNROLL; ++u) MULTI128().postOp(fn, vals[u]); - } - - // Store - #pragma unroll - for (int i = 0; i < MINDSTS; i++) { - for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]); - } - #pragma unroll - for (int i=MINDSTS; i -__device__ int ptrAlign128(T* ptr) { return (uint64_t)ptr % alignof(uint32_t); } - -#define PACKELEMS (sizeof(Pack128) / sizeof(T)) -#define AUTOUNROLL (UNROLL*((MINSRCS==1 && MINDSTS==1) ? 2 : 1)) - -template +template __device__ __forceinline__ void ReduceOrCopyMulti( - const int tid, const int nthreads, uint64_t* redOpArgs, bool postOp, int nsrcs, const T** srcs, int ndsts, T** dsts, Int N + int thread, int nThreads, + uint64_t redArg, uint64_t *preOpArgs, bool postOp, + int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs, + IntBytes nElts ) { - Int Nrem = N; - if (Nrem <= 0) return; - - int w = tid / WARP_SIZE; // Warp number - int nw = nthreads / WARP_SIZE; // Number of warps - int t = tid % WARP_SIZE; // Thread (inside the warp) + //int nWarps = nThreads/WARP_SIZE; + //int warp = thread/WARP_SIZE; + int lane = thread%WARP_SIZE; // Check that all is 16B aligned. If not don't use 16B load/stores. - int align = 0; - #pragma unroll - for (int i=0; i + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + if (nBytesAhead == 0) return; - // main loop - Int Npack = (Nrem / (PACKELEMS*AUTOUNROLL*WARP_SIZE)) * (AUTOUNROLL*WARP_SIZE); // round down - Int Nelem = Npack * PACKELEMS; - - ReduceCopy128bMulti - (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; - - // slightly less optimized for section when we don't have full unrolling - Npack = Nrem / PACKELEMS; - Nelem = Npack * PACKELEMS; - - ReduceCopy128bMulti - (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Npack); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; + reduceCopyPacks + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + if (nBytesAhead == 0) return; } - // unrolled, by-type (mostly for unaligned buffers) - Int Nelem = (Nrem / (AUTOUNROLL*PACKELEMS/2*WARP_SIZE)) * (AUTOUNROLL*PACKELEMS/2*WARP_SIZE); // round down + reduceCopyPacks + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); + if (nBytesAhead == 0) return; - ReduceCopyMulti - (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nelem); - - Nrem -= Nelem; - if (Nrem == 0) return; - offset += Nelem; - - // no unroll, by type. Should finish what's remaining. - ReduceCopyMulti - (w, nw, t, redOpArgs, postOp, nsrcs, srcs, ndsts, dsts, offset, Nrem); + reduceCopyPacks + (nThreads, /*&*/thread, redArg, preOpArgs, postOp, + nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead); } +// Copies from srcAddr to dstAddr using multimem load/store. The amount copied +// will be at most Unroll*BytePerPack*WARP_SIZE. If Partial=1, then the amount +// will be the min() of that and nBytesAhead. If srcAddr is not BytePerPack +// aligned then the amount copied will be less by (srcAddr%BytePerPack) since +// we begin loads at the first pack containing the first element. +template +__device__ __forceinline__ void copyMultimemMultimem_WarpUnrolled( + int lane, RedFn redFn, bool postOp, uintptr_t srcAddr, uintptr_t dstAddr, + IntBytes nBytesAhead, uint32_t scratchAddr + ) { +#if 0 + int srcMisalign = SrcAligned ? 0 : srcAddr%BytePerPack; + srcAddr -= srcMisalign; + + BytePack reg[Unroll]; + int offset = lane*BytePerPack; + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (!Partial || (offset < srcMisalign + nBytesAhead)) { + reg[u] = applyLoadMultimem(redFn, srcAddr+offset); + if (postOp) reg[u] = applyPostOp(redFn, reg[u]); + } + offset += WARP_SIZE*BytePerPack; + } + + if (SrcAligned && DstAligned) { + offset = lane*BytePerPack; + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (!Partial || offset < nBytesAhead) { + multimem_st_global(dstAddr+offset, reg[u]); + } + offset += WARP_SIZE*BytePerPack; + } + } else { + __syncwarp(); + offset = lane*BytePerPack; + #pragma unroll Unroll + for (int u=0; u < Unroll; u++) { + if (!Partial || (offset < srcMisalign + nBytesAhead)) { + st_shared(scratchAddr+offset, reg[u]); + } + offset += WARP_SIZE*BytePerPack; + } + __syncwarp(); + if (!SrcAligned) { + // Ignore the beginning of the first pack corresponding to bytes overread + // due to misalignment. + nBytesAhead = min(nBytesAhead, Unroll*WARP_SIZE*BytePerPack - srcMisalign); + } + copyGlobalShared_WarpUnrolled + + (lane, dstAddr, scratchAddr+srcMisalign, nBytesAhead); + } +#endif +} + +// copyMultimemMultimem_IfEnabled has two overloads: the enabled case whose first arg +// has type `std::true_type` and the disabled case with first arg `std::false_type`. +// This is to guard the template instantiations of Apply_LoadMultimem on types/ops where +// they aren't supported. A nicer approach is to use C++17's "if constexpr". +template +__device__ __forceinline__ void copyMultimemMultimem_IfEnabled( + std::false_type enabled/*=false*/, + int thread, int nThreads, uint64_t redArg, bool postOp, + void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr + ) { + // nop +} + +template +__device__ __forceinline__ void copyMultimemMultimem_IfEnabled( + std::true_type enabled/*=true*/, + int thread, int nThreads, uint64_t redArg, bool postOp, + void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr + ) { + static_assert(std::is_signed::value, "IntBytes must be a signed integral type."); + + constexpr int BytePerPack = Apply_LoadMultimem::PackSize; + using T = typename RedFn::EltType; + constexpr int Unroll = ncclNvlsUnroll(BytePerPack); + constexpr int BytePerHunk = Unroll*WARP_SIZE*BytePerPack; + int nWarps = nThreads/WARP_SIZE; + int warp = thread/WARP_SIZE; + int lane = thread%WARP_SIZE; + RedFn redFn(redArg); + + uintptr_t srcAddr = cvta_to_global(srcPtr); + uintptr_t dstAddr = cvta_to_global(dstPtr); + IntBytes warpBytesAhead = nElts*sizeof(T); + bool partialHunkIsFront; + + // First handle misalignment of srcAddr. + if ((BytePerPack != sizeof(T)) && (srcAddr%BytePerPack != 0)) { + // If srcAddr isn't pack aligned then the first hunk processed will be short + // the same number of bytes as srcAddr's misalignment. + if (warp == 0) { + partialHunkIsFront = true; + goto PartialHunk; // "call" PartialHunk() + PartialHunkFrontReturn: + warp = nWarps; + } + warp -= 1; // Rotate warp numbers for load balancing + int advanced = BytePerHunk-(srcAddr%BytePerPack); // since copyMultimemMultimem_WarpUnrolled shorts by the misalignment + srcAddr += advanced; // srcAddr is now pack aligned + dstAddr += advanced; + warpBytesAhead -= advanced; + } + + warpBytesAhead -= warp*BytePerHunk; + srcAddr += warp*BytePerHunk; + dstAddr += warp*BytePerHunk; + // Now that srcAddr is pack aligned detect if dstAddr is pack aligned. + if ((BytePerPack == sizeof(T)) || (dstAddr%BytePerPack == 0)) { + while (BytePerHunk <= warpBytesAhead) { + copyMultimemMultimem_WarpUnrolled + + (lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr); + srcAddr += nWarps*BytePerHunk; + dstAddr += nWarps*BytePerHunk; + warpBytesAhead -= nWarps*BytePerHunk; + } + } else { + while (BytePerHunk <= warpBytesAhead) { + copyMultimemMultimem_WarpUnrolled + + (lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr); + srcAddr += nWarps*BytePerHunk; + dstAddr += nWarps*BytePerHunk; + warpBytesAhead -= nWarps*BytePerHunk; + } + } + + if (0 < warpBytesAhead) { + partialHunkIsFront = false; + goto PartialHunk; // "call" PartialHunk() + PartialHunkBackReturn:; + } + return; + +PartialHunk: + // We have to handle a partial hunk possibly at the front and back of the + // buffer. We generate the code once here since its a lot of instructions, + // and then simulate function calls with gotos. + copyMultimemMultimem_WarpUnrolled + + (lane, redFn, postOp, srcAddr, dstAddr, warpBytesAhead, warpScratchAddr); + if (partialHunkIsFront) goto PartialHunkFrontReturn; + goto PartialHunkBackReturn; +} + +template +__device__ __forceinline__ void copyMultimemMultimem( + int thread, int nThreads, uint64_t redArg, bool postOp, + void *srcPtr, void *dstPtr, IntBytes nElts, uint32_t warpScratchAddr + ) { + constexpr bool Enabled = Apply_LoadMultimem::PackSize != 0; + copyMultimemMultimem_IfEnabled( + /*enabled=*/std::integral_constant(), + thread, nThreads, redArg, postOp, srcPtr, dstPtr, nElts, warpScratchAddr); +} #endif // COMMON_KERNEL_H_ diff --git a/src/collectives/device/functions.cu b/src/collectives/device/functions.cu index 04179da5c2..018e511c0b 100644 --- a/src/collectives/device/functions.cu +++ b/src/collectives/device/functions.cu @@ -10,6 +10,9 @@ #include "common.h" __shared__ ncclShmemData ncclShmem; +#if __CUDA_ARCH__ < 700 + __shared__ ulong2 ncclShmemPerWarp[ncclShmemScratchWarpSize()*(NCCL_MAX_NTHREADS/WARP_SIZE)/sizeof(ulong2)]; +#endif #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) #else @@ -22,7 +25,8 @@ __shared__ ncclShmemData ncclShmem; NCCL_FUNC5(func, TREE, devredop, type, nullify), \ NCCL_FUNC5(func, RING, devredop, type, nullify), \ NCCL_FUNC5(func, COLLNET_DIRECT, devredop, type, nullify), \ - NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify) + NCCL_FUNC5(func, COLLNET_CHAIN, devredop, type, nullify), \ + NCCL_FUNC5(func, NVLS, devredop, type, nullify) #if defined(__CUDA_BF16_TYPES_EXIST__) // Must be consistent with ncclDataType_t diff --git a/src/collectives/device/msccl_kernel.cu b/src/collectives/device/msccl_kernel.cu index 0948d193f3..7a78865983 100644 --- a/src/collectives/device/msccl_kernel.cu +++ b/src/collectives/device/msccl_kernel.cu @@ -114,7 +114,7 @@ __device__ __forceinline__ static void threadBlockCopy( for (int r = 0; r < numloops; r++) { \ srcOffset = srcBaseOffset + (ssize_t)mscclShmem.mscclTB.reductionSrcOffsets[t->reductionPointer+r] * sizePerMscclChunk; \ reduceInput = load(srcPointer + srcOffset); \ - o = redFn(reduceInput, o); \ + o = applyReduce(redFn, reduceInput, o); \ } #define MSCCL_REDUCE_UNROLL_LOOP_B(numloops) \ diff --git a/src/collectives/device/onerank_reduce.cu b/src/collectives/device/onerank_reduce.cu index 569c4c7a0f..9ebe8eea84 100644 --- a/src/collectives/device/onerank_reduce.cu +++ b/src/collectives/device/onerank_reduce.cu @@ -7,7 +7,7 @@ #include "devcomm.h" #include "collectives.h" -#include "reduce_kernel.h" +#include "common_kernel.h" #include "common.h" namespace { @@ -40,8 +40,10 @@ namespace { i1 = i1 < eltN ? i1 : eltN; src += i0; dst += i0; - ReduceOrCopyMulti - (tid, tn, &(we->redOpArg), true, 1, &src, 1, &dst, i1-i0); + void *vsrc = (void*)src; + void *vdst = (void*)dst; + ReduceOrCopyMulti + (tid, tn, we->redOpArg, &(we->redOpArg), true, 1, &vsrc, 1, &vdst, i1-i0); } } } diff --git a/src/collectives/device/op128.h b/src/collectives/device/op128.h index 46fc8df501..8ee13fbb6f 100644 --- a/src/collectives/device/op128.h +++ b/src/collectives/device/op128.h @@ -8,29 +8,27 @@ #define OP128_H_ inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) { - asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" - : "=l"(v0), "=l"(v1) : "l"(ptr)); + v0 = __builtin_nontemporal_load(ptr); + v1 = __builtin_nontemporal_load(ptr+1); } inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) { - asm volatile("st.volatile.global.v2.u64 [%2], {%0,%1};" - :: "l"(v0), "l"(v1), "l"(ptr)); + __builtin_nontemporal_store(v0, ptr); + __builtin_nontemporal_store(v1, ptr+1); } inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) { - uint64_t* shmemAsmPtr; - asm volatile("cvta.to.shared.u64 %0, %1;" : "=l"(shmemAsmPtr) : "l"(shmemGenericPtr)); - return shmemAsmPtr; + return (uint64_t*)shmemGenericPtr; } inline __device__ void loadShmem128(uint64_t* shmemAsmPtr, uint64_t &v0, uint64_t &v1) { - asm volatile("ld.volatile.shared.v2.u64 {%0,%1}, [%2];" - : "=l"(v0), "=l"(v1) : "l"(shmemAsmPtr)); + v0 = *(shmemAsmPtr); + v1 = *(shmemAsmPtr+1); } inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_t v1) { - asm volatile("st.volatile.shared.v2.u64 [%2], {%0,%1};" - :: "l"(v0), "l"(v1), "l"(shmemAsmPtr)); + *(shmemAsmPtr) = v0; + *(shmemAsmPtr+1) = v1; } template @@ -46,23 +44,300 @@ inline __device__ void loadShmemMisaligned128(T *ptr, uint64_t &v0, uint64_t &v1 // Produce 4 bytes of sub-register type by reading 2 4-byte // aligned values and shifting. uint32_t lo, hi; - asm("ld.shared.b32 %0,[%1];" : "=r"(lo) : "l"(ptr4+e+0)); - asm("ld.shared.b32 %0,[%1];" : "=r"(hi) : "l"(ptr4+e+1)); + lo = __builtin_nontemporal_load(ptr4+e+0); + hi = __builtin_nontemporal_load(ptr4+e+1); tmp4[e] = __funnelshift_r(lo, hi, 8*(int(reinterpret_cast(ptr))%4)); } } else if(sizeof(T) == 4) { #pragma unroll for(int e=0; e < 4; e++) - asm("ld.shared.b32 %0,[%1];" : "=r"(tmp4[e]) : "l"(ptr+e)); + tmp4[e] = __builtin_nontemporal_load(ptr+e); } else /*sizeof(T)==8*/ { #pragma unroll for(int e=0; e < 2; e++) - asm("ld.shared.b64 %0,[%1];" : "=l"(tmp8[e]) : "l"(ptr+e)); + tmp8[e] = __builtin_nontemporal_load(ptr+e); } v0 = tmp8[0]; v1 = tmp8[1]; } + +template +__device__ __forceinline__ uint32_t cvta_to_shared(T* ptr) { + return (uint32_t)(uint64_t)(ptr); +} +template +__device__ __forceinline__ uintptr_t cvta_to_global(T* ptr) { + return (uintptr_t)(ptr); +} + +template +__device__ __forceinline__ T* cvta_from_shared(uint32_t shptr) { + return (T*)shptr; +} +template +__device__ __forceinline__ T* cvta_from_global(uintptr_t gptr) { + return (T*)gptr; +} + +//////////////////////////////////////////////////////////////////////////////// +// BytePack: struct of bytes. + +template +union BytePack; +template<> +union BytePack<1> { + uint8_t u8, native; +}; +template<> +union BytePack<2> { + BytePack<1> half[2]; + uint8_t u8[2]; + uint16_t u16, native; +}; +template<> +union BytePack<4> { + BytePack<2> half[2]; + uint8_t u8[4]; + uint16_t u16[2]; + uint32_t u32, native; +}; +template<> +union BytePack<8> { + BytePack<4> half[2]; + uint8_t u8[8]; + uint16_t u16[4]; + uint32_t u32[2]; + uint64_t u64, native; +}; +template<> +union alignas(16) BytePack<16> { + BytePack<8> half[2]; + uint8_t u8[16]; + uint16_t u16[8]; + uint32_t u32[4]; + uint64_t u64[2]; + ulong2 ul2, native; +#ifndef USE_INDIRECT_FUNCTION_CALL + inline __device__ BytePack<16>& operator=(BytePack<16> other) { + u64[0] = other.u64[0]; + u64[1] = other.u64[1]; + return *this; + } +#endif +}; + +template +__device__ __forceinline__ BytePack toPack(T value) { + union { BytePack p; T v; }; + v = value; + return p; +} +template +__device__ __forceinline__ T fromPack(BytePack pack) { + union { BytePack p; T v; }; + p = pack; + return v; +} + +//////////////////////////////////////////////////////////////////////////////// +// Load/store of BytePack using integral addresses. + +template __device__ BytePack ld_global(uintptr_t addr); +template __device__ BytePack ld_volatile_global(uintptr_t addr); +//template __device__ BytePack ld_shared(uint32_t addr); +//template __device__ BytePack ld_volatile_shared(uint32_t addr); +template __device__ void st_global(uintptr_t addr, BytePack value); +//template __device__ void st_shared(uint32_t addr, BytePack value); + +// Used to define implementations for above prototypes. +#define DEFINE_ld_st(bytes, data_cxx_ty, data_ptx_ty, data_reg_ty, space, addr_cxx_ty, addr_reg_ty) \ + template<> \ + __device__ __forceinline__ BytePack ld_##space(addr_cxx_ty addr) { \ + data_cxx_ty tmp; \ + tmp = *((data_cxx_ty *)addr); \ + BytePack ans; \ + ans.native = tmp; \ + return ans; \ + } \ + template<> \ + __device__ __forceinline__ BytePack ld_volatile_##space(addr_cxx_ty addr) { \ + data_cxx_ty tmp; \ + tmp = __builtin_nontemporal_load((data_cxx_ty *)addr); \ + BytePack ans; \ + ans.native = tmp; \ + return ans; \ + } \ + template<> \ + __device__ __forceinline__ void st_##space(addr_cxx_ty addr, BytePack value) { \ + data_cxx_ty tmp = value.native; \ + *((data_cxx_ty *)addr) = tmp; \ + } +// Single-byte types use 4-byte registers since there is no 1-byte register +// character for asm blocks. See https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints +DEFINE_ld_st(1, uint8_t, b8, r, global, uintptr_t, l) +//DEFINE_ld_st(1, uint32_t, b8, r, shared, uint32_t, r) +DEFINE_ld_st(2, uint16_t, b16, h, global, uintptr_t, l) +//DEFINE_ld_st(2, uint16_t, b16, h, shared, uint32_t, r) +DEFINE_ld_st(4, uint32_t, b32, r, global, uintptr_t, l) +//DEFINE_ld_st(4, uint32_t, b32, r, shared, uint32_t, r) +DEFINE_ld_st(8, uint64_t, b64, l, global, uintptr_t, l) +//DEFINE_ld_st(8, uint64_t, b64, l, shared, uint32_t, r) +#undef DEFINE_ld_st + +#define DEFINE_ld_st_16(space, addr_cxx_ty, addr_reg_ty) \ + template<> \ + __device__ __forceinline__ BytePack<16> ld_##space<16>(addr_cxx_ty addr) { \ + BytePack<16> ans; \ + ans.u64[0] = *((uint64_t*)addr); \ + ans.u64[1] = *((uint64_t*)addr+1); \ + return ans; \ + } \ + template<> \ + __device__ __forceinline__ BytePack<16> ld_volatile_##space<16>(addr_cxx_ty addr) { \ + BytePack<16> ans; \ + ans.u64[0] = __builtin_nontemporal_load((uint64_t*)addr); \ + ans.u64[1] = __builtin_nontemporal_load((uint64_t*)addr+1); \ + return ans; \ + } \ + template<> \ + __device__ __forceinline__ void st_##space<16>(addr_cxx_ty addr, BytePack<16> value) { \ + *((uint64_t*)addr) = value.u64[0]; \ + *((uint64_t*)addr+1) = value.u64[1]; \ + } +DEFINE_ld_st_16(global, uintptr_t, l) +//DEFINE_ld_st_16(shared, uint32_t, r) +#undef DEFINE_ld_st_16 + +//////////////////////////////////////////////////////////////////////////////// +// Atomic load/store using c++ pointers. + +__device__ __forceinline__ uint64_t ld_volatile_global(uint64_t *ptr) { + uint64_t ans; + ans = __builtin_nontemporal_load(ptr); + return ans; +} +__device__ __forceinline__ uint64_t ld_relaxed_sys_global(uint64_t *ptr) { + uint64_t ans; + ans = __builtin_nontemporal_load(ptr); + return ans; +} +__device__ __forceinline__ uint64_t ld_acquire_sys_global(uint64_t *ptr) { + uint64_t ans; + ans = __atomic_load_n(ptr ,__ATOMIC_SEQ_CST); + return ans; +} + +__device__ __forceinline__ void st_volatile_global(uint64_t *ptr, uint64_t val) { + __builtin_nontemporal_store(val, ptr); +} +__device__ __forceinline__ void st_relaxed_sys_global(uint64_t *ptr, uint64_t val) { + __builtin_nontemporal_store(val, ptr); +} +__device__ __forceinline__ void st_release_sys_global(uint64_t *ptr, uint64_t val) { + __atomic_store_n(ptr, val, __ATOMIC_SEQ_CST); +} + +__device__ __forceinline__ void fence_acq_rel_sys() { + //asm volatile("membar.sys;" ::: "memory"); +} +__device__ __forceinline__ void fence_acq_rel_gpu() { + //asm volatile("membar.gl;" ::: "memory"); +} + +//////////////////////////////////////////////////////////////////////////////// +// Multimem stores of BytePack. + +template +__device__ __forceinline__ void multimem_st_global(uintptr_t addr, BytePack val); + +#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 +template<> +__device__ __forceinline__ void multimem_st_global<4>(uintptr_t addr, BytePack<4> val) { + asm volatile("multimem.st.global.b32 [%0], %1;" :: "l"(addr), "r"(val.u32) : "memory"); +} +template<> +__device__ __forceinline__ void multimem_st_global<8>(uintptr_t addr, BytePack<8> val) { + asm volatile("multimem.st.global.b64 [%0], %1;" :: "l"(addr), "l"(val.u64) : "memory"); +} +template<> +__device__ __forceinline__ void multimem_st_global<16>(uintptr_t addr, BytePack<16> val) { + asm volatile("multimem.st.global.v4.f32 [%0], {%1,%2,%3,%4};" + :: "l"(addr), "r"(val.u32[0]), "r"(val.u32[1]), "r"(val.u32[2]), "r"(val.u32[3]) + : "memory"); +} +#else +template +__device__ __forceinline__ void multimem_st_global(uintptr_t addr, BytePack val) { + // nop +} +#endif + +#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 +// Warp-uniform memory copy from shared address (not generic) to global memory. +// The number of bytes copied is `min(MaxBytes, nBytesAhead)`, a negative value +// is interpeted as zero. EltSize is the guaranteed alignment of the addresses and sizes. +template +__device__ __forceinline__ void copyGlobalShared_WarpUnrolled( + int lane, uintptr_t dstAddr, uint32_t srcAddr, IntBytes nBytesAhead + ) { + static_assert(std::is_signed::value, "`IntBytes` must be a signed integral type."); + int nBytes = min(nBytesAhead, (IntBytes)MaxBytes); + int nFrontBytes = min(nBytes, (16 - int(dstAddr%16))%16); + int nMiddleBytes = (nBytes-nFrontBytes) & -16; + int nBackBytes = (nBytes-nFrontBytes) % 16; + + { int backLane = WARP_SIZE-1 - lane; + bool hasFront = lane*EltSize < nFrontBytes; + bool hasBack = backLane*EltSize < nBackBytes; + int offset = hasFront ? lane*EltSize : (nBytes - (backLane+1)*EltSize); + if (hasFront | hasBack) { + BytePack tmp = ld_shared(srcAddr+offset); + // Can't use multimem_st since it doesn't support EltSize==2 + st_global(dstAddr+offset, tmp); + } + } + + srcAddr += nFrontBytes; + int srcMisalign = EltSize < 4 ? (srcAddr%4) : 0; + srcAddr += -srcMisalign + lane*16; + dstAddr += nFrontBytes + lane*16; + nMiddleBytes -= lane*16; + #pragma unroll + for (int u=0; u < divUp(MaxBytes, WARP_SIZE*16); u++) { + if (nMiddleBytes <= 0) break; + union { + BytePack<4> b4[4]; + BytePack<16> b16; + }; + b4[0] = ld_shared<4>(srcAddr + 0*4); + b4[1] = ld_shared<4>(srcAddr + 1*4); + b4[2] = ld_shared<4>(srcAddr + 2*4); + b4[3] = ld_shared<4>(srcAddr + 3*4); + if (srcMisalign != 0) { + BytePack<4> b4_4 = ld_shared<4>(srcAddr + 4*4); + b4[0].u32 = __funnelshift_r(b4[0].u32, b4[1].u32, srcMisalign*8); + b4[1].u32 = __funnelshift_r(b4[1].u32, b4[2].u32, srcMisalign*8); + b4[2].u32 = __funnelshift_r(b4[2].u32, b4[3].u32, srcMisalign*8); + b4[3].u32 = __funnelshift_r(b4[3].u32, b4_4.u32, srcMisalign*8); + } + if (Multimem) multimem_st_global<16>(dstAddr, b16); + else st_global<16>(dstAddr, b16); + + srcAddr += WARP_SIZE*16; + dstAddr += WARP_SIZE*16; + nMiddleBytes -= WARP_SIZE*16; + } +} +#else +template +__device__ __forceinline__ void copyGlobalShared_WarpUnrolled( + int lane, uintptr_t dstAddr, uint32_t srcAddr, IntBytes nBytesAhead + ) { + // nop +} +#endif + #endif diff --git a/src/collectives/device/primitives.h b/src/collectives/device/primitives.h index 50377e3323..19191b6c12 100644 --- a/src/collectives/device/primitives.h +++ b/src/collectives/device/primitives.h @@ -10,6 +10,7 @@ #include #include "reduce_kernel.h" // for reduction funcs +#include "common_kernel.h" #include "common.h" #define NCCL_SPINS_BEFORE_CHECK_ABORT 1000000 @@ -37,12 +38,13 @@ * to how that protocol operates with a consistent interface so that our * algorithm code can operate protocol parametrically. */ -template +template struct ProtoSimple { static constexpr int Id = NCCL_PROTO_SIMPLE; static constexpr int SlicePerChunk = SlicePerChunk_1; static constexpr int StepPerSlice = StepPerSlice_1; static constexpr int Unroll = Unroll_1; + static constexpr bool NVLS = NVLS_1; // Data bytes (no flags etc) in one step of the fifo queue. __device__ static int calcBytePerStep() { diff --git a/src/collectives/device/prims_ll.h b/src/collectives/device/prims_ll.h index f85b43d612..a0060dae46 100644 --- a/src/collectives/device/prims_ll.h +++ b/src/collectives/device/prims_ll.h @@ -77,11 +77,6 @@ private: #endif } - static inline __device__ uint32_t __funnelshift_r(uint32_t lo, uint32_t hi, uint32_t shift) { - uint64_t val64 = ((uint64_t)lo+((uint64_t)hi<<32))>>(shift&31); - return (uint32_t)val64; - } - uint32_t abort = 0; inline __device__ int checkAbort(int &spins, int send) { @@ -426,18 +421,18 @@ private: } if (SRC) { data = dl.loadFinish(); - if (SrcBuf == Input) data = MULTI().preOp(redOp, data); + if (SrcBuf == Input) data = applyPreOp(redOp, data); } if (RECV) { - data = !SRC ? peerData : MULTI()(redOp, peerData, data); - #pragma unroll + data = !SRC ? peerData : applyReduce(redOp, peerData, data); + #pragma unroll MaxRecv for (int i=1; i < MaxRecv && i < fan.nrecv(); i++) { peerData = readLLFinish(offset, line, i); - data = MULTI()(redOp, peerData, data); + data = applyReduce(redOp, peerData, data); } } - if (postOp) data = MULTI().postOp(redOp, data); + if (postOp) data = applyPostOp(redOp, data); // Send : inter-node, then intra-node, then local if (SEND) { @@ -511,13 +506,13 @@ private: uint64_t dataD; dl.loadBegin(dstElts, eltInLine); dataD = dl.loadFinish(); - dataD = MULTI()(redOp, dataD, data); + dataD = applyReduce(redOp, dataD, data); if (MULTISRCS){ for (int i = 1; i < nsrcs; i++){ dl.loadBegin(srcs[i], eltInLine); srcs[i] += eltPerTrip; data = dl.loadFinish(); - dataD = MULTI()(redOp, dataD, data); + dataD = applyReduce(redOp, dataD, data); } } mscclStoreData(dstElts, dataD, eltInLine); diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index d5737bc157..48f7796df8 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -249,9 +249,9 @@ private: if (SrcBuf == Input) { #pragma unroll for (int u=0; u().preOp(redOp, v[u]); + v[u] = applyPreOp(redOp, v[u]); if (!flagThread) - v[u+1] = MULTI().preOp(redOp, v[u+1]); + v[u+1] = applyPreOp(redOp, v[u+1]); } } } @@ -262,8 +262,8 @@ private: uint64_t* ptr = recvPtr(0)+ll128Offset; #pragma unroll for (int u=0; u()(redOp, vr[u], v[u]) : vr[u]; - v[u+1] = SRC ? MULTI()(redOp, vr[u+1], v[u+1]) : vr[u+1]; + v[u] = SRC ? applyReduce(redOp, vr[u], v[u]) : vr[u]; + v[u+1] = SRC ? applyReduce(redOp, vr[u+1], v[u+1]) : vr[u+1]; } } @@ -283,20 +283,24 @@ private: needReload &= (0 == checkAbort(spins, i, 0)); } while (__any(needReload)); + #pragma unroll + for (int u=0; u()(redOp, vr[u], v[u]); - v[u+1] = MULTI()(redOp, vr[u+1], v[u+1]); + v[u] = applyReduce(redOp, vr[u], v[u]); + v[u+1] = applyReduce(redOp, vr[u+1], v[u+1]); } } } /********************** End Recv ************************/ - if (postOp && !FuncTraits::IsPostOpIdentity) { + if (postOp) { #pragma unroll for (int u=0; u().postOp(redOp, v[u]); - v[u+1] = MULTI().postOp(redOp, v[u+1]); + v[u] = applyPostOp(redOp, v[u]); + v[u+1] = applyPostOp(redOp, v[u+1]); } } @@ -332,14 +336,6 @@ private: __device__ __forceinline__ void GenericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) { constexpr int SRC = SrcBuf != -1 ? 1 : 0; constexpr int DST = DstBuf != -1 ? 1 : 0; - static_assert(-1<=SrcBuf && SrcBuf < 2, "Uhoh"); - static_assert(-1<=DstBuf && DstBuf < 2, "Uhoh"); - static_assert(DstBuf!=Input, "Mistake?"); - #if 0 - assert((SrcBuf==-1) == (srcIx==-1)); - assert((DstBuf==-1) == (dstIx==-1)); - #endif - T const *srcPtr = SrcBuf == -1 ? nullptr : userBufs[SrcBuf] + srcIx; T *dstPtr = DstBuf == -1 ? nullptr : userBufs[DstBuf] + dstIx; int wireOffset = WireWordPerSlice*warp + 2*wid; @@ -403,18 +399,18 @@ private: loadRegsFinish(regsD); #pragma unroll for (int u=0; u()(redOp, regs[u], regsD[u]); + regsD[u] = applyReduce(redOp, regs[u], regsD[u]); if (!flagThread) - regsD[u+1] = MULTI()(redOp, regs[u+1], regsD[u+1]); + regsD[u+1] = applyReduce(redOp, regs[u+1], regsD[u+1]); } if (MULTISRCS){ for (int i = 1; i < nsrcs; i++){ loadRegsBegin(regs, srcs[i], eltInSlice); loadRegsFinish(regs); for (int u=0; u()(redOp, regs[u], regsD[u]); + regsD[u] = applyReduce(redOp, regs[u], regsD[u]); if (!flagThread) - regsD[u+1] = MULTI()(redOp, regs[u+1], regsD[u+1]); + regsD[u+1] = applyReduce(redOp, regs[u+1], regsD[u+1]); } } } diff --git a/src/collectives/device/prims_simple.h b/src/collectives/device/prims_simple.h index 324742cb6b..fb5b0e0af9 100644 --- a/src/collectives/device/prims_simple.h +++ b/src/collectives/device/prims_simple.h @@ -13,9 +13,9 @@ #include "msccl/msccl_struct.h" template + int SlicePerChunk, int StepPerSlice, int Unroll, int P2p, bool NVLS> class Primitives< - T, RedOp, Fan, Direct, ProtoSimple, P2p + T, RedOp, Fan, Direct, ProtoSimple, P2p > { static constexpr int MaxRecv = Fan::MaxRecv, MaxSend = Fan::MaxSend; static constexpr int Input=0, Output=1; @@ -30,8 +30,10 @@ class Primitives< SizesFifoEnabled = 0x100, DirectWrite = 0x200, DirectRead = 0x400, - ThreadsSynced = 0x800; - const int tid; + ThreadsSynced = 0x800, + NvlsMinPolling = 0x1000, + NvlsRecv = 0x2000; + const int tid, tidInBlock; int nthreads; int nworkers; const int stepSize; @@ -49,7 +51,7 @@ class Primitives< int volatile *connSizesFifoPtr; // (flags & SizesFifoEnabled) T *directBuff; // !(flags & SizesFifoEnabled) }; - uint64_t volatile *connStepPtr; + uint64_t *connStepPtr; uint64_t connStepCache; // Cache last seen value of (*connStepPtr) uint64_t* barriers; uint64_t* barrier_next; @@ -66,28 +68,15 @@ private: // Don't use barrier 0 as it's used by the final sync inline __device__ void barrier() { -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) + flags |= ThreadsSynced; if (nthreads == WARP_SIZE) __syncwarp(); else barrier_by_group(); -#else - if (nthreads == WARP_SIZE) - __syncwarp(); - else - asm volatile("bar.sync %0, %1;" :: "r"(15-group), "r"(nthreads)); -#endif - flags |= ThreadsSynced; } + inline __device__ void subBarrier() { -#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) barrier(); -#else - if (nworkers == nthreads) - barrier(); - else - asm volatile("bar.sync %0, %1;" :: "r"(8-group), "r"(nworkers)); -#endif } inline __device__ bool checkAbort(int &spins) { @@ -102,6 +91,19 @@ private: return flags & Aborted; } + inline __device__ uint64_t loadStepValue(uint64_t* ptr) { + #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 + if (NVLS && (flags & NvlsMinPolling)) { + uint64_t ans; + asm("multimem.ld_reduce.acquire.sys.global.min.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr))); + return ans; + } + #endif + // volatile is faster than acquire but not as correct. Make sure ReduceOrCopyMulti + // loads data using volatile so it doesn't see stale data in L1. + return atomicAdd((unsigned long long *)ptr, 0); + } + template __device__ __forceinline__ void waitPeer(intptr_t dstIx, intptr_t remoteIx, int offset, int nelts) { const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send; @@ -112,7 +114,7 @@ private: int spins = 0; while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) { __builtin_amdgcn_s_sleep(1); - connStepCache = atomicAdd((unsigned long long *)connStepPtr, 0); + connStepCache = loadStepValue(connStepPtr); if (checkAbort(spins)) break; //if (spins == 0) printf("r=%d b=%d t=%d SPUN OUT got=%d want=%d\n", ncclShmem.comm.rank, blockIdx.x, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice)); if (spins == 0) traceData(__LINE__, threadIdx.x, int(connStepCache + (isSendNotRecv ? NCCL_STEPS : 0)), int(step+StepPerSlice)); @@ -153,12 +155,18 @@ private: } template - inline __device__ void postPeer() { + inline __device__ void postPeer(bool dataStored) { if ((flags & Send*RolePostSend) && next_hdp_reg) STORE((unsigned int *)next_hdp_reg, 0x1); if (flags & (Recv*RolePostRecv | Send*RolePostSend)) { step += StepPerSlice; + if (Send && (flags & RolePostSend) && dataStored) +#ifdef __GFX9__ + __asm__ __volatile__("buffer_wbinvl1_vol"); +#else + __threadfence_system(); +#endif STORE(connStepPtr, step); } } @@ -202,6 +210,7 @@ private: // barrier(); // post(); // } // Since we no longer unroll, new branch added here + #pragma unroll 1 do { sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset; if (Src && (flags & (SrcBuf==Input ? RoleInput : RoleOutput))) @@ -212,8 +221,13 @@ private: subBarrier(); /* if user abort the kernel, we don't need to actually perform copy/reduce; just set size * to 0 to avoid unnecessary workload. */ - size_t workSize = ncclShmem.aborted ? 0 : sliceSize; - if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { + int workSize = ncclShmem.aborted ? 0 : sliceSize; + if (NVLS && ncclShmem.groups[group].nvlsRecv) { + void* src = ncclShmem.groups[group].srcs[0]; + void* dst = ncclShmem.groups[group].dsts[0]; + copyMultimemMultimem(tid, nworkers, ncclShmem.redOpArgs[0], postOp, src, dst, workSize, + cvta_to_shared(ncclScratchForWarp(tidInBlock/WARP_SIZE))); + } else if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { // We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy if (Send) { @@ -230,11 +244,10 @@ private: } #endif - // (1-Send) is only there to avoid compilation errors in case MaxSend=0 (and Send=0). - ReduceOrCopyMulti - (tid, nworkers, nullptr, false, - 1, (T const**)ncclShmem.groups[group].srcs, - fan.nsend(), (T**)ncclShmem.groups[group].dsts+1, + ReduceOrCopyMulti + (tid, nworkers, /*redArg*/0, /*preOpArgs*/nullptr, /*postOp*/false, + 1, ncclShmem.groups[group].srcs, + fan.nsend(), ncclShmem.groups[group].dsts+1, workSize); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME) @@ -254,7 +267,6 @@ private: } } else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) { // For broadcast in CollNet to do empty send - #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY) if (tid == 0) { NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY, sliceSize*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), @@ -268,10 +280,10 @@ private: } #endif - ReduceOrCopyMulti - (tid, nworkers, ncclShmem.redOpArgs, postOp, - Recv, (T const**)ncclShmem.groups[group].srcs, - Dst, (T**)ncclShmem.groups[group].dsts, + ReduceOrCopyMulti + (tid, nworkers, ncclShmem.redOpArgs[0], nullptr, postOp, + Recv, ncclShmem.groups[group].srcs, + Dst, ncclShmem.groups[group].dsts, workSize); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME) @@ -289,7 +301,6 @@ private: #endif } else { - #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY) if (tid == 0) { NpKit::CollectGpuEvent(NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_ENTRY, sliceSize*sizeof(T), 0, NPKIT_GET_GPU_TIMESTAMP(), @@ -303,12 +314,12 @@ private: } #endif - constexpr int PreOpN = SrcBuf != Input ? 0 : - DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; - ReduceOrCopyMulti - (tid, nworkers, ncclShmem.redOpArgs, postOp, - Recv*fan.nrecv()+Src, (T const**)ncclShmem.groups[group].srcs, - Send*fan.nsend()+Dst, (T**)ncclShmem.groups[group].dsts, + constexpr int PreOpSrcs = SrcBuf != Input ? 0 : + DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1; + ReduceOrCopyMulti + (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, + Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs, + Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts, workSize); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_PRIM_COLLECT_DATA_PROCESS_TIME) @@ -327,13 +338,7 @@ private: } barrier(); // This barrier has a counterpart in following loop -#ifdef __GFX9__ - if (Send && (flags & RolePostSend) && index == 0) __asm__ __volatile__("buffer_wbinvl1_vol"); -#else - if (Send && (flags & RolePostSend) && index == 0) __threadfence_system(); -#endif - __syncwarp(); - postPeer(); + postPeer(0 < sliceSize); offset += sliceSize; slice += 1; } while (slice < SlicePerChunk && offset < nelem); @@ -343,6 +348,7 @@ private: // slices are all empty. Since empty slices are the uncommon case, and // worker perf is the limiter, perf-wise this loop is effectively unentered, // hence just a single branch insn. + #pragma unroll 1 while (slice < SlicePerChunk) { sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset; { // Only workers could have Wait roles so we know the slice must be empty @@ -350,13 +356,7 @@ private: waitPeer(0, 0, 0, 0); } barrier(); // Has couterpart in preceding worker-only loop. -#ifdef __GFX9__ - if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) __asm__ __volatile__("buffer_wbinvl1_vol"); -#else - if (Send && (flags & RolePostSend) && sliceSize > 0 && index == 0) __threadfence_system(); -#endif - __syncwarp(); - postPeer(); + postPeer(0 < sliceSize); offset += sliceSize; slice += 1; } @@ -371,19 +371,19 @@ private: nsrcs++; if (MULTISRCS){ ReduceOrCopyMulti - (tid, nworkers, ncclShmem.redOpArgs, false, nsrcs, (T const**)srcs, 1, (T**)dsts, nelem); + (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, nsrcs, (void **)srcs, 1, (void **)dsts, nelem); } else { ReduceOrCopyMulti - (tid, nworkers, ncclShmem.redOpArgs, false, 2, (T const**)srcs, 1, (T**)dsts, nelem); + (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 2, (void **)srcs, 1, (void **)dsts, nelem); } } if (COPY){ ReduceOrCopyMulti - (tid, nworkers, ncclShmem.redOpArgs, false, 1, (T const**)srcs, 1, (T**)dsts, nelem); + (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)srcs, 1, (void **)dsts, nelem); if (MULTISRCS) { for (int i = 1; i < nsrcs; i++){ ReduceOrCopyMulti - (tid, nworkers, ncclShmem.redOpArgs, false, 1, (T const**)&srcs[i], 1, (T**)&dsts[i], nelem); + (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, (void **)&srcs[i], 1, (void **)&dsts[i], nelem); } } } @@ -396,44 +396,46 @@ private: // shift: peer offset to avoid all ranks sending to or receiving from same peer template __device__ __forceinline__ void - ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp) { + ScatterGatherOp(intptr_t inpIx, intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift, bool postOp) { constexpr int DirectRecv = 1 && Direct && DirectRecv1; constexpr int DirectSend = 1 && Direct && DirectSend1; int offset = 0; // slice offset int sliceSize = stepSize*StepPerSlice; int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32); // per-peer slice size + #pragma unroll 1 for (int slice=0; slice(0, inpIx, offset, realSize); subBarrier(); + #pragma unroll 1 // Loop over peers for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; - const T* src0 = (T*)ncclShmem.groups[group].srcs[0] + peerOffset; - int realPeerSize = min(realSize, totalElem-peerOffset); + if (skip >= 0 && i >= skip) pOffset += peerElem; + void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset; + int realPeerSize = min(realSize, totalElem-pOffset); if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) { - ReduceOrCopyMulti(tid, nworkers, ncclShmem.redOpArgs, false, 1, &src0, 1, (T**)ncclShmem.groups[group].dsts+i, realPeerSize); + ReduceOrCopyMulti(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false, 1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize); // Mark for threadfence at the end - if (tid == 0) ncclShmem.groups[group].totalSendSize[slice] += realPeerSize; + fenceNeeded |= true; } } } else if (Recv) { if (flags & RoleOutput) ncclShmem.groups[group].dsts[0] = userBuff + outIx + offset; - int peerOffset = index*peerElem; - if (skip >= 0 && index >= skip) peerOffset += peerElem; + int pOffset = index*peerOffset; + if (skip >= 0 && index >= skip) pOffset += peerElem; // Adjust remote index with peer offset in case we are directly pulling from peer's output buffer - waitPeer(outIx, outIx+peerOffset, offset, realSize); + waitPeer(outIx, outIx+pOffset, offset, realSize); subBarrier(); if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) { // Since waitPeer sets srcs[0] to output buffer + offset, we are doing a direct-write based recv @@ -441,21 +443,17 @@ private: } else { for (int j=0; j= 0 && i >= skip) peerOffset += peerElem; - T* dst0 = (T*)ncclShmem.groups[group].dsts[0] + peerOffset; - int realPeerSize = min(realSize, totalElem-peerOffset); - if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, ncclShmem.redOpArgs, postOp, 1, (const T**)ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize); + pOffset = i*peerOffset; + if (skip >= 0 && i >= skip) pOffset += peerElem; + void* dst0 = (T*)ncclShmem.groups[group].dsts[0] + pOffset; + int realPeerSize = min(realSize, totalElem-pOffset); + if (realPeerSize > 0) ReduceOrCopyMulti(tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp, 1, ncclShmem.groups[group].srcs+i, 1, &dst0, realPeerSize); } } } } - barrier(); - // If we indeed send something, threadfence - if (Send && (flags & RolePostSend) && ncclShmem.groups[group].totalSendSize[slice] > 0 && index == 0) - __threadfence_system(); - __syncwarp(); - postPeer(); + fenceNeeded = __any(fenceNeeded); + postPeer(fenceNeeded); offset += realSize; } } @@ -471,25 +469,33 @@ private: } if (flags & RoleWaitRecv) { ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs() + if ((index == 0) && (flags & RoleWaitRecv)) { + if (conn->flags & NCCL_NVLS_MIN_POLL) { + flags |= NvlsMinPolling; + ncclShmem.groups[group].nvlsRecv = 1; + } else { + ncclShmem.groups[group].nvlsRecv = 0; + } + } connStepPtr = conn->tail; - connStepCache = atomicAdd((unsigned long long *)connStepPtr, 0); + connStepCache = loadStepValue(connStepPtr); flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0; if (Direct) { // User buffers have been registered - if ((conn->direct & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { + if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { if (connIndex == 1 && P2p == 0) { flags |= DirectRead; // scatter-reduce use direct pull } else { flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite : (e->direct & NCCL_DIRECT_READ) ? DirectRead : 0; } - } else if (conn->direct & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { + } else if (conn->flags & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { if (connIndex == 1 && P2p == 0) { flags |= DirectRead; // scatter-reduce use direct pull } else { // direct read not allowed in non-register case // otherwise, in one-to-multi send, we could mix empty send and intermediate send - flags |= (conn->direct & NCCL_DIRECT_WRITE) ? DirectWrite : 0; + flags |= (conn->flags & NCCL_DIRECT_WRITE) ? DirectWrite : 0; } } } @@ -511,8 +517,9 @@ private: } if (flags & RoleWaitSend) { ncclShmem.groups[group].sendConns[index] = conn; // WaitSend role saves since that's who needs it in setDataPtrs() + flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0; connStepPtr = conn->head; - connStepCache = atomicAdd((unsigned long long *)connStepPtr, 0); + connStepCache = loadStepValue(connStepPtr); flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0; if (flags & OffsFifoEnabled) connOffsFifoPtr = conn->offsFifo; @@ -523,20 +530,20 @@ private: connSizesFifoPtr = conn->sizesFifo; } else if (Direct) { // User buffers have been registered - if ((conn->direct & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { + if ((conn->flags & (NCCL_IPC_READ|NCCL_IPC_WRITE)) && e != nullptr && e->regUsed) { if (connIndex == 1 && P2p == 0) { flags |= DirectRead; // scatter-reduce use direct pull } else { flags |= (e->direct & NCCL_DIRECT_WRITE) ? DirectWrite : (e->direct & NCCL_DIRECT_READ) ? DirectRead : 0; } - } else if (conn->direct & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { + } else if (conn->flags & (NCCL_DIRECT_WRITE|NCCL_DIRECT_READ)) { if (connIndex == 1 && P2p == 0) { flags |= DirectRead; // scatter-reduce use direct pull } else { // direct read not allowed in non-register case // otherwise, in one-to-multi send, we could mix empty send and intermediate send - flags |= (conn->direct & NCCL_DIRECT_WRITE) ? DirectWrite : 0; + flags |= (conn->flags & NCCL_DIRECT_WRITE) ? DirectWrite : 0; } } } @@ -549,7 +556,7 @@ private: int tid, int nthreads, int const *recvPeers, int const *sendPeers, void const *inputBuf, void *outputBuf, uint64_t redOpArg, uint32_t group=0, struct ncclWorkElem* e = nullptr ): - tid(tid), + tid(tid), tidInBlock(threadIdx.x), stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T)) { // For send operations, we need an extra warp to overlap the threadfence and the copy @@ -566,7 +573,7 @@ private: this->fan = Fan(nrecv, nsend); constexpr int ThreadPerSync = 8; - static_assert(MaxSend < ThreadPerSync && MaxRecv < ThreadPerSync, "Not enough threads to cover all peers"); + static_assert(MaxSend <= ThreadPerSync && MaxRecv <= ThreadPerSync, "Not enough threads to cover all peers"); int g = tid / ThreadPerSync; int ng = nthreads / ThreadPerSync; @@ -726,6 +733,9 @@ private: genericOp<0, 1, 0, 1, Input, Output>(inpIx, outIx, remoteOutIx, eltN, postOp); } + __device__ __forceinline__ void recvSend(int eltN, bool postOp=false) { + genericOp<0, 0, 1, 1, -1, -1>(-1, -1, -1, eltN, postOp); + } __device__ __forceinline__ void recvCopySend(intptr_t outIx, int eltN, bool postOp=false) { genericOp<0, 0, 1, 1, -1, Output>(-1, outIx, -1, eltN, postOp); } @@ -756,25 +766,21 @@ private: } __device__ __forceinline__ void - scatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false); + scatter(intptr_t inpIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) { + ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false); } __device__ __forceinline__ void - directScatter(intptr_t inpIx, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<0, 1, 0, 1>(inpIx, -1, totalElem, peerElem, skip, shift, /*postOp=*/false); + directScatter(intptr_t inpIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) { + ScatterGatherOp<0, 1, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false); } __device__ __forceinline__ void - gather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift, bool postOp=false) { - ScatterGatherOp<0, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, postOp); + gather(intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift, bool postOp=false) { + ScatterGatherOp<0, 0, 1, 0>(-1, outIx, totalElem, peerElem, peerOffset, skip, shift, postOp); } __device__ __forceinline__ void - directGather(intptr_t outIx, int totalElem, int peerElem, int skip, int shift) { - ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, skip, shift, /*postOp=*/false); - } - - __device__ __forceinline__ void recvSend(int eltN) { - genericOp<0, 0, 1, 1, -1, -1>(-1, -1, -1, eltN, /*postOp=*/false); + directGather(intptr_t outIx, int totalElem, int peerElem, int peerOffset, int skip, int shift) { + ScatterGatherOp<1, 0, 1, 0>(-1, outIx, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false); } // MSCCL primitives diff --git a/src/collectives/device/reduce_kernel.h b/src/collectives/device/reduce_kernel.h index 88ba7d788f..ab490450ef 100644 --- a/src/collectives/device/reduce_kernel.h +++ b/src/collectives/device/reduce_kernel.h @@ -9,396 +9,447 @@ #ifndef NCCL_REDUCE_KERNEL_H_ #define NCCL_REDUCE_KERNEL_H_ -#include "common_kernel.h" +#include "op128.h" #include #include -template -struct FuncNull { - __device__ FuncNull(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return 0; - } -}; +//////////////////////////////////////////////////////////////////////////////// +// The reduction function classes. All classes must: +// 1. Expose the `EltType` typedef. +// 2. Have constructor taking no arguments (default constructible). +// 3. Have constructor taking `uint64_t opArg`. template -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return x + y; - } -}; - +struct FuncNull { using EltType = T; __device__ FuncNull(uint64_t opArg=0) {}; }; template -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return x * y; - } -}; - +struct FuncSum { using EltType = T; __device__ FuncSum(uint64_t opArg=0) {}; }; template -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return (x < y) ? y : x; - } -}; - +struct FuncProd { using EltType = T; __device__ FuncProd(uint64_t opArg=0) {}; }; template -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ T operator()(const T x, const T y) const { - return (x < y) ? x : y; - } -}; +struct FuncMin { using EltType = T; __device__ FuncMin(uint64_t opArg=0) {}; }; +template +struct FuncMax { using EltType = T; __device__ FuncMax(uint64_t opArg=0) {}; }; + +template struct FuncPreMulSum; +template struct FuncSumPostDiv; + +//////////////////////////////////////////////////////////////////////////////// +// Trait classes for reduction functions. Given a function (FuncSum, etc.) +// and a number of elements in a pack, will reduce, preOp, or postOp a pack +// of elements. These classes are intended to be specialized for specific +// combinations of reduction function and pack size. + +template +struct Apply_Reduce /*{ + static BytePack reduce( + Fn fn, BytePack a, BytePack b + ); +}*/; +template +struct Apply_PreOp/*{ + static constexpr bool IsIdentity; + static BytePack preOp(Fn fn, BytePack a); +}*/; +template +struct Apply_PostOp/*{ + static constexpr bool IsIdentity; + static BytePack postOp(Fn fn, BytePack a); +}*/; +template +struct Apply_LoadMultimem/*{ + static constexpr int PackSize; // 0 if not implemented + static BytePack load(Fn fn, uintptr_t addr); +}*/; + +//////////////////////////////////////////////////////////////////////////////// +// Public API for calling the trait classes. These take the data elements as a +// pack of any type, which could be a BytePack or any integral type (uint64_t, +// uint32_t, etc.), and will return a new pack where each element has been +// transformed appropriately. + +template +__device__ __forceinline__ Pack applyReduce(Fn fn, Pack a, Pack b) { + return fromPack( + Apply_Reduce + ::reduce(fn, toPack(a), toPack(b)) + ); +} + +template +__device__ __forceinline__ Pack applyPreOp(Fn fn, Pack a) { + return fromPack( + Apply_PreOp + ::preOp(fn, toPack(a)) + ); +} + +template +__device__ __forceinline__ Pack applyPostOp(Fn fn, Pack a) { + return fromPack( + Apply_PostOp + ::postOp(fn, toPack(a)) + ); +} template -struct FuncTraits { // generic implementation for FuncSum,Prod,Min,Max - static constexpr bool IsPreOpIdentity = true; - static constexpr bool IsPostOpIdentity = true; - - template - __device__ static T preOp(Fn, T x) { return x; } - template - __device__ static T postOp(Fn, T x) { return x; } -}; - -#define NCCL_MASK0 0x00ff00ff -#define NCCL_MASK1 0xff00ff00 -static __device__ uint32_t addChar4(const uint32_t x, const uint32_t y) { - /* This can be used both for signed and unsigned 8-bit addition */ - const uint32_t x0 = x & NCCL_MASK0; - const uint32_t x1 = x & NCCL_MASK1; - const uint32_t y0 = y & NCCL_MASK0; - const uint32_t y1 = y & NCCL_MASK1; - const uint32_t r0 = (x0+y0); - const uint32_t r1 = (x1+y1); - return (r0 & NCCL_MASK0) | (r1 & NCCL_MASK1); +__device__ __forceinline__ BytePack::PackSize> applyLoadMultimem(Fn fn, uintptr_t addr) { + return Apply_LoadMultimem::load(fn, addr); } -template<> -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vadd4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - return addChar4(x, y); -#endif - } - __device__ int8_t operator()(const int8_t x, const int8_t y) const { - return x+y; - } -}; -template<> -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vadd4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - return addChar4(x, y); -#endif - } - __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { - return x+y; +//////////////////////////////////////////////////////////////////////////////// +// Apply_Reduce + +// General recursive definition (EltPerPack > 1). This is how we iterate over +// all elements in a pack of any size, by breaking it into halves. Eventually +// we'll hit a base case (a more specific template specialization which takes +// precedence). +template +struct Apply_Reduce { + template + __device__ static BytePack reduce(Fn fn, BytePack a, BytePack b) { + a.half[0] = Apply_Reduce::reduce(fn, a.half[0], b.half[0]); + a.half[1] = Apply_Reduce::reduce(fn, a.half[1], b.half[1]); + return a; } }; -static __device__ uint32_t mulChar4(const uint32_t x, const uint32_t y) { - /* This can be used both for signed and unsigned 8-bit multiplication */ - union converter { uint32_t storage; char4 a; }; - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = cx.a.x * cy.a.x; - cr.a.y = cx.a.y * cy.a.y; - cr.a.z = cx.a.z * cy.a.z; - cr.a.w = cx.a.w * cy.a.w; - return cr.storage; -} - -template<> -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { - return mulChar4(x, y); - } - __device__ int8_t operator()(const int8_t x, const int8_t y) const { - return x*y; +// Base case definitions (EltPerPack == 1) +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncSum fn, BytePack a, BytePack b) { + return a; } }; -template<> -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { - return mulChar4(x, y); +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncSum fn, BytePack a, BytePack b) { + return toPack(fromPack(a) + fromPack(b)); } - __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { - return x*y; +}; +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncProd fn, BytePack a, BytePack b) { + return toPack(fromPack(a) * fromPack(b)); + } +}; +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncMin fn, BytePack a, BytePack b) { + return toPack(min(fromPack(a), fromPack(b))); + } +}; +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncMax fn, BytePack a, BytePack b) { + return toPack(max(fromPack(a), fromPack(b))); } }; +// Optimizations for specfic types and element count combinations: template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - union converter { uint32_t storage; char4 a; }; - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vmax4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = max(cx.a.x, cy.a.x); - cr.a.y = max(cx.a.y, cy.a.y); - cr.a.z = max(cx.a.z, cy.a.z); - cr.a.w = max(cx.a.w, cy.a.w); - return cr.storage; -#endif - } - __device__ int8_t operator()(const int8_t x, const int8_t y) const { - return (x>y) ? x : y; +struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncSum fn, BytePack<4> a, BytePack<4> b) { + constexpr uint32_t lo = 0x00ff00ff; + constexpr uint32_t hi = ~lo; + uint32_t x = a.u32; + uint32_t y = b.u32; + a.u32 = (((x&lo) + (y&lo))&lo) + (((x&hi) + (y&hi))&hi); + return a; } }; template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - union converter { uint32_t storage; uchar4 a; }; - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vmax4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = max(cx.a.x, cy.a.x); - cr.a.y = max(cx.a.y, cy.a.y); - cr.a.z = max(cx.a.z, cy.a.z); - cr.a.w = max(cx.a.w, cy.a.w); - return cr.storage; -#endif - } - __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { - return (x>y) ? x : y; +struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncSum fn, BytePack<4> a, BytePack<4> b) { + return Apply_Reduce, 4>::reduce(FuncSum(), a, b); } }; -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - union converter { uint32_t storage; char4 a; }; - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vmin4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = min(cx.a.x, cy.a.x); - cr.a.y = min(cx.a.y, cy.a.y); - cr.a.z = min(cx.a.z, cy.a.z); - cr.a.w = min(cx.a.w, cy.a.w); - return cr.storage; +#if 300 <= __CUDA_ARCH__ && __CUDA_ARCH__ < 500 + template<> + struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncMin fn, BytePack<4> a, BytePack<4> b) { + uint32_t z=0; + asm("vmin4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(a.u32) : "r"(a.u32), "r"(b.u32), "r"(z)); + return a; + } + }; + template<> + struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncMin fn, BytePack<4> a, BytePack<4> b) { + int32_t z=0; + asm("vmin4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(a.u32) : "r"(a.u32), "r"(b.u32), "r"(z)); + return a; + } + }; + template<> + struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncMax fn, BytePack<4> a, BytePack<4> b) { + uint32_t z=0; + asm("vmax4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(a.u32) : "r"(a.u32), "r"(b.u32), "r"(z)); + return a; + } + }; + template<> + struct Apply_Reduce, /*EltPerPack=*/4> { + __device__ static BytePack<4> reduce(FuncMax fn, BytePack<4> a, BytePack<4> b) { + int32_t z=0; + asm("vmax4.s32.s32.s32 %0, %1, %2, %3;" : "=r"(a.u32) : "r"(a.u32), "r"(b.u32), "r"(z)); + return a; + } + }; #endif - } - __device__ int8_t operator()(const int8_t x, const int8_t y) const { - return (x -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - union converter { uint32_t storage; uchar4 a; }; - __device__ uint32_t operator()(const uint32_t x, const uint32_t y) const { -#if (__CUDA_ARCH__ >= 300) && (__CUDA_ARCH__ < 500) - int32_t rv, z=0; - asm("vmin4.u32.u32.u32 %0, %1, %2, %3;" : "=r"(rv) : "r"(x), "r"(y), "r"(z)); - return rv; -#else - converter cx, cy, cr; - cx.storage = x; - cy.storage = y; - cr.a.x = min(cx.a.x, cy.a.x); - cr.a.y = min(cx.a.y, cy.a.y); - cr.a.z = min(cx.a.z, cy.a.z); - cr.a.w = min(cx.a.w, cy.a.w); - return cr.storage; -#endif - } - __device__ uint8_t operator()(const uint8_t x, const uint8_t y) const { - return (x -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ half2 operator()(const half2 x, const half2 y) const { +#define SPECIALIZE_REDUCE(Fn, T, EltPerPack, Vec, expr_of_x_y) \ + template<> \ + struct Apply_Reduce, EltPerPack> { \ + __device__ __forceinline__ static BytePack reduce( \ + Fn fn, BytePack a, BytePack b \ + ) { \ + Vec x = fromPack(a); \ + Vec y = fromPack(b); \ + return toPack(expr_of_x_y); \ + } \ + }; + #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - return __hadd2(x, y); + SPECIALIZE_REDUCE(FuncSum, half, 1, half, __hadd(x, y)) + SPECIALIZE_REDUCE(FuncSum, half, 2, half2, __hadd2(x, y)) + SPECIALIZE_REDUCE(FuncProd, half, 1, half, __hmul(x, y)) + SPECIALIZE_REDUCE(FuncProd, half, 2, half2, __hmul2(x, y)) #else - float2 fx, fy, fr; - fx = __half22float2(x); - fy = __half22float2(y); - fr.x = fx.x + fy.x; - fr.y = fx.y + fy.y; - return __float22half2_rn(fr); + SPECIALIZE_REDUCE(FuncSum, half, 1, half, __float2half(__half2float(x) + __half2float(y))) + SPECIALIZE_REDUCE(FuncProd, half, 1, half, __float2half(__half2float(x) * __half2float(y))) #endif + +#if __CUDA_ARCH__ >= 800 + SPECIALIZE_REDUCE(FuncMin, half, 1, half, __hmin(x, y)) + SPECIALIZE_REDUCE(FuncMin, half, 2, half2, __hmin2(x, y)) + SPECIALIZE_REDUCE(FuncMax, half, 1, half, __hmax(x, y)) + SPECIALIZE_REDUCE(FuncMax, half, 2, half2, __hmax2(x, y)) +#else + SPECIALIZE_REDUCE(FuncMin, half, 1, half, __float2half(fminf(__half2float(x), __half2float(y)))) + SPECIALIZE_REDUCE(FuncMax, half, 1, half, __float2half(fmaxf(__half2float(x), __half2float(y)))) +#endif + +#if defined(RCCL_BFLOAT16) +#if __CUDA_ARCH__ >= 800 + SPECIALIZE_REDUCE(FuncSum, __nv_bfloat16, 1, __nv_bfloat16, __hadd(x, y)) + SPECIALIZE_REDUCE(FuncSum, __nv_bfloat16, 2, __nv_bfloat162, __hadd2(x, y)) + SPECIALIZE_REDUCE(FuncProd, __nv_bfloat16, 1, __nv_bfloat16, __hmul(x, y)) + SPECIALIZE_REDUCE(FuncProd, __nv_bfloat16, 2, __nv_bfloat162, __hmul2(x, y)) + SPECIALIZE_REDUCE(FuncMin, __nv_bfloat16, 1, __nv_bfloat16, __hmin(x, y)) + SPECIALIZE_REDUCE(FuncMin, __nv_bfloat16, 2, __nv_bfloat162, __hmin2(x, y)) + SPECIALIZE_REDUCE(FuncMax, __nv_bfloat16, 1, __nv_bfloat16, __hmax(x, y)) + SPECIALIZE_REDUCE(FuncMax, __nv_bfloat16, 2, __nv_bfloat162, __hmax2(x, y)) +#else + SPECIALIZE_REDUCE(FuncSum, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)((float)(x) + (float)(y))) + SPECIALIZE_REDUCE(FuncProd, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)((float)(x) * (float)(y))) + SPECIALIZE_REDUCE(FuncMin, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)(fminf((float)(x), (float)(y)))) + SPECIALIZE_REDUCE(FuncMax, rccl_bfloat16, 1, rccl_bfloat16, (rccl_bfloat16)(fmaxf((float)(x), (float)(y)))) +#endif +#endif + +#undef SPECIALIZE_REDUCE + +//////////////////////////////////////////////////////////////////////////////// +// Apply_PreOp + +// General recursive definition (EltPerPack > 1) +template +struct Apply_PreOp { + static constexpr bool IsIdentity = Apply_PreOp::IsIdentity; + template + __device__ static BytePack preOp(Fn fn, BytePack a) { + #if __cpp_if_constexpr + if constexpr(!IsIdentity) { + #else + if (!IsIdentity) { + #endif + // The `if (!IsIdentity)` condition is not strictly necessary, but it may help + // compiler in that it won't have to tear a register apart for no reason + // just to put it back together again. + a.half[0] = Apply_PreOp::preOp(fn, a.half[0]); + a.half[1] = Apply_PreOp::preOp(fn, a.half[1]); + } + return a; } - __device__ half operator()(const half x, const half y) const { +}; +// Base case definition (EltPerPack == 1), by default is identity function. +template +struct Apply_PreOp { + static constexpr bool IsIdentity = true; + template + __device__ static BytePack preOp(Fn fn, BytePack a) { + return a; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Apply_PostOp + +// General recursive definition (EltPerPack > 1) +template +struct Apply_PostOp { + static constexpr bool IsIdentity = Apply_PostOp::IsIdentity; + template + __device__ static BytePack postOp(Fn fn, BytePack a) { + #if __cpp_if_constexpr + if constexpr(!IsIdentity) { + #else + if (!IsIdentity) { + #endif + // The `if (!IsIdentity)` condition is not strictly necessary, but it may help + // compiler in that it won't have to tear a register apart for no reason + // just to put it back together again. + a.half[0] = Apply_PostOp::postOp(fn, a.half[0]); + a.half[1] = Apply_PostOp::postOp(fn, a.half[1]); + } + return a; + } +}; +// Base case definition (EltPerPack == 1), by default is identity function. +template +struct Apply_PostOp { + static constexpr bool IsIdentity = true; + template + __device__ static BytePack postOp(Fn fn, BytePack a) { + return a; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +// FuncPreMulSum + +// General definition for all integral types, float, and double. +template +struct FuncPreMulSum { + using EltType = T; + T scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; T val; }; + u64 = opArg; + scalar = val; + } +}; + +template<> +struct FuncPreMulSum { + using EltType = half; #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - return __hadd(x, y); -#else - return __float2half( __half2float(x) + __half2float(y) ); -#endif + half2 scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; half val; }; + u64 = opArg; + scalar.x = val; + scalar.y = val; } +#else + float scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; half val; }; + u64 = opArg; + scalar = __half2float(val); + } +#endif }; #if defined(RCCL_BFLOAT16) -template<> -struct FuncSum { - __device__ FuncSum(uint64_t opArg=0) {} - __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { - return (rccl_bfloat16)((float)x + (float)y); - } -}; + template<> + struct FuncPreMulSum { + using EltType = rccl_bfloat16; + #if __CUDA_ARCH__ >= 800 + __nv_bfloat162 scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; __nv_bfloat16 val; }; + u64 = opArg; + scalar.x = val; + scalar.y = val; + } + #else + float scalar; + __device__ FuncPreMulSum(uint64_t opArg=0) { + union { uint64_t u64; rccl_bfloat16 val; }; + u64 = opArg; + scalar = (float)(val); + } + #endif + }; #endif +template +struct Apply_Reduce, /*EltPerPack=*/1> { + __device__ static BytePack reduce(FuncPreMulSum fn, BytePack a, BytePack b) { + // FuncPreMulSum reduce dispatches to FuncSum. + return Apply_Reduce, 1>::reduce(FuncSum(), a, b); + } +}; + +// PreOp of FuncPreMulSum for integral types, float, and double. +template +struct Apply_PreOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp(FuncPreMulSum fn, BytePack a) { + return toPack(fromPack(a) * fn.scalar); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Apply_PreOp of FuncPreMulSum for float16. + template<> -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ half2 operator()(const half2 x, const half2 y) const { +struct Apply_PreOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp(FuncPreMulSum fn, BytePack a) { + #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 + return toPack(__hmul(fromPack(a), fn.scalar.x)); + #else + return toPack(__float2half(__half2float(fromPack(a)) * fn.scalar)); + #endif + } +}; #if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - return __hmul2(x, y); -#else - float2 fx, fy, fr; - fx = __half22float2(x); - fy = __half22float2(y); - fr.x = fx.x * fy.x; - fr.y = fx.y * fy.y; - return __float22half2_rn(fr); + template<> + struct Apply_PreOp, /*EltPerPack=*/2> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp(FuncPreMulSum fn, BytePack a) { + return toPack(__hmul2(fromPack(a), fn.scalar)); + } + }; #endif - } - __device__ half operator()(const half x, const half y) const { -#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - return __hmul(x, y); -#else - return __float2half( __half2float(x) * __half2float(y) ); -#endif - } -}; + +//////////////////////////////////////////////////////////////////////////////// +// Apply_PreOp of FuncPreMulSum for bfloat16. #if defined(RCCL_BFLOAT16) -template<> -struct FuncProd { - __device__ FuncProd(uint64_t opArg=0) {} - __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { - return (rccl_bfloat16)((float)x * (float)y); - } -}; + template<> + struct Apply_PreOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp( + FuncPreMulSum fn, BytePack a + ) { + #if __CUDA_ARCH__ >= 800 + return toPack<__nv_bfloat16>(__hmul(fromPack<__nv_bfloat16>(a), fn.scalar.x)); + #else + return toPack((rccl_bfloat16)((float)(fromPack(a)) * fn.scalar)); + #endif + } + }; + #if __CUDA_ARCH__ >= 800 + template<> + struct Apply_PreOp, /*EltPerPack=*/2> { + static constexpr bool IsIdentity = false; + __device__ static BytePack preOp( + FuncPreMulSum<__nv_bfloat16> fn, BytePack a + ) { + return toPack<__nv_bfloat162>(__hmul2(fromPack<__nv_bfloat162>(a), fn.scalar)); + } + }; + #endif #endif -template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ half2 operator()(const half2 x, const half2 y) const { - float2 fx, fy, fr; - fx = __half22float2(x); - fy = __half22float2(y); - fr.x = fmaxf(fx.x, fy.x); - fr.y = fmaxf(fx.y, fy.y); - return __float22half2_rn(fr); - } - __device__ half operator()(const half x, const half y) const { - float fx, fy, fm; - fx = __half2float(x); - fy = __half2float(y); - fm = fmaxf(fx, fy); - return __float2half(fm); - } -}; - -#if defined(RCCL_BFLOAT16) -template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { - return (float)x < (float)y ? y : x; - } -}; -#endif - -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ half2 operator()(const half2 x, const half2 y) const { - float2 fx, fy, fr; - fx = __half22float2(x); - fy = __half22float2(y); - fr.x = fminf(fx.x, fy.x); - fr.y = fminf(fx.y, fy.y); - return __float22half2_rn(fr); - } - __device__ half operator()(const half x, const half y) const { - float fx, fy, fm; - fx = __half2float(x); - fy = __half2float(y); - fm = fminf(fx, fy); - return __float2half(fm); - } -}; - -#if defined(RCCL_BFLOAT16) -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ rccl_bfloat16 operator()(const rccl_bfloat16 x, const rccl_bfloat16 y) const { - return (float)x < (float)y ? x : y; - } -}; -#endif - -template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ float operator()(float x, float y) const { - return fmaxf(x, y); - } -}; -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ float operator()(float x, float y) const { - return fminf(x, y); - } -}; - -template<> -struct FuncMax { - __device__ FuncMax(uint64_t opArg=0) {} - __device__ double operator()(double x, double y) const { - return fmax(x, y); - } -}; -template<> -struct FuncMin { - __device__ FuncMin(uint64_t opArg=0) {} - __device__ double operator()(double x, double y) const { - return fmin(x, y); - } -}; +//////////////////////////////////////////////////////////////////////////////// +// FuncSumPostDiv template struct IsFloatingPoint: std::false_type {}; @@ -414,182 +465,128 @@ template<> struct IsFloatingPoint: std::true_type {}; template::value> -struct FuncSumPostDiv; +struct FuncSumPostDiv_IntOnly; template -struct FuncSumPostDiv: FuncSum { - static constexpr bool IsPreOpIdentity = true; - static constexpr bool IsPostOpIdentity = false; - int n; - __device__ FuncSumPostDiv(uint64_t opArg): n(opArg) {} - // inherits FuncSum::operator() - __device__ T preOp(T x) const { return x; } - __device__ T postOp(T x) const { return T(x/n); } +struct FuncSumPostDiv: FuncSumPostDiv_IntOnly { + __device__ FuncSumPostDiv(uint64_t opArg=0): + FuncSumPostDiv_IntOnly(opArg) { + } }; template -struct FuncSumPostDiv { +struct FuncSumPostDiv_IntOnly: FuncSum { + using EltType = T; + int divisor; + __device__ FuncSumPostDiv_IntOnly(uint64_t opArg=0): divisor(opArg) {} +}; + +template +struct FuncSumPostDiv_IntOnly { static_assert(sizeof(T)!=sizeof(T), "FuncSumPostDiv is only for implementing ncclAvg on integral types."); }; template -struct FuncPreMulSum: FuncSum { // integral T since all floats are specialized below - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - T scale; - __device__ FuncPreMulSum(uint64_t opArg) { scale = *(T*)&opArg; } - // inherits FuncSum::operator() - __device__ T preOp(T x) const { return x*scale; } - __device__ T postOp(T x) const { return x; } -}; - -template<> -struct FuncPreMulSum: FuncSum { - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - double scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale = *(double*)&opArg; - } - // inherits FuncSum::operator() - __device__ double preOp(double x) const { - return IsPreOpIdentity ? x : x*scale; - } - __device__ double postOp(double x) const { - return IsPostOpIdentity ? x : x*scale; +struct Apply_Reduce, /*EltPerPack=*/1>: + Apply_Reduce, 1> { + __device__ static BytePack reduce(FuncSumPostDiv fn, BytePack a, BytePack b) { + // FuncSumPostDiv reduce dispatches to FuncSum. + return Apply_Reduce, 1>::reduce(FuncSum(), a, b); } }; -template<> -struct FuncPreMulSum: FuncSum { - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - float scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale = *(float*)&opArg; - } - // inherits FuncSum::operator() - __device__ float preOp(float x) const { - return IsPreOpIdentity ? x : x*scale; - } - __device__ float postOp(float x) const { - return IsPostOpIdentity ? x : x*scale; - } -}; - -template<> -struct FuncPreMulSum: FuncSum { - // Change these to switch between all prescale, all postscale, or both by sqrt(N). - // Obviously, the only invalid combination is both true. An improvement would be - // make this parameterized as a build time setting and passed here through - // preprocessor definitions. - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - -#if __CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610 - half2 scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale.x = *(half*)&opArg; - scale.y = scale.x; - } - // inherits FuncSum::operator() - __device__ half preOp(half x) const { - return IsPreOpIdentity ? x : __hmul(x, scale.x); - } - __device__ half2 preOp(half2 x) const { - return IsPreOpIdentity ? x : __hmul2(x, scale); - } - __device__ half postOp(half x) const { - return IsPostOpIdentity ? x : __hmul(x, scale.x); - } - __device__ half2 postOp(half2 x) const { - return IsPostOpIdentity ? x : __hmul2(x, scale); - } -#else - float scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale = __half2float(*(half*)&opArg); - } - // inherits FuncSum::operator() - __device__ half preOp(half x) const { - return IsPreOpIdentity ? x : __float2half(__half2float(x)*scale); - } - __device__ half2 preOp(half2 x) const { - if (IsPreOpIdentity) - return x; - else { - float2 a = __half22float2(x); - a.x *= scale; - a.y *= scale; - return __float22half2_rn(a); - } - } - __device__ half postOp(half x) const { - return IsPostOpIdentity ? x : __float2half(__half2float(x)*scale); - } - __device__ half2 postOp(half2 x) const { - if (IsPostOpIdentity) - return x; - else { - float2 a = __half22float2(x); - a.x *= scale; - a.y *= scale; - return __float22half2_rn(a); - } - } -#endif -}; - -#if defined(RCCL_BFLOAT16) -template<> -struct FuncPreMulSum: FuncSum { - // Change these to switch between all prescale, all postscale, or both by sqrt(N). - // Obviously, the only invalid combination is both true. An improvement would be - // make this parameterized as a build time setting and passed here through - // preprocessor definitions. - static constexpr bool IsPreOpIdentity = false; - static constexpr bool IsPostOpIdentity = true; - - float scale; - __device__ FuncPreMulSum(uint64_t opArg) { - scale = *(rccl_bfloat16*)&opArg; - } - // inherits FuncSum::operator() - __device__ rccl_bfloat16 preOp(rccl_bfloat16 x) const { - return IsPreOpIdentity ? x : (rccl_bfloat16)((float)x*scale); - } - __device__ rccl_bfloat16 postOp(rccl_bfloat16 x) const { - return IsPostOpIdentity ? x : (rccl_bfloat16)((float)x*scale); - } -}; -#endif - template -struct FuncTraits> { - static constexpr bool IsPreOpIdentity = FuncPreMulSum::IsPreOpIdentity; - static constexpr bool IsPostOpIdentity = FuncPreMulSum::IsPostOpIdentity; - - template - __device__ static U preOp(FuncPreMulSum fn, U x) { - return fn.preOp(x); - } - template - __device__ static U postOp(FuncPreMulSum fn, U x) { - return fn.postOp(x); +struct Apply_PostOp, /*EltPerPack=*/1> { + static constexpr bool IsIdentity = false; + __device__ static BytePack postOp(FuncSumPostDiv fn, BytePack a) { + return toPack(fromPack(a) / fn.divisor); } }; -template -struct FuncTraits> { - static constexpr bool IsPreOpIdentity = FuncSumPostDiv::IsPreOpIdentity; - static constexpr bool IsPostOpIdentity = FuncSumPostDiv::IsPostOpIdentity; - template - __device__ static U preOp(FuncSumPostDiv fn, U x) { - return fn.preOp(x); - } - template - __device__ static U postOp(FuncSumPostDiv fn, U x) { - return fn.postOp(x); - } +//////////////////////////////////////////////////////////////////////////////// +// Apply_LoadMultimem + +template +struct Apply_LoadMultimem { + static constexpr int PackSize = 0; // Indicates not implemented }; + +#define SIZEOF_BytePack_field_u16 2 +#define PTX_REG_BytePack_field_u16 "h" + +#define SIZEOF_BytePack_field_u32 4 +#define PTX_REG_BytePack_field_u32 "r" + +#define SIZEOF_BytePack_field_u64 8 +#define PTX_REG_BytePack_field_u64 "l" + +#define DEFINE_Apply_LoadMultimem(Fn, T, op, ptx_ty, pack_field) \ + template<> \ + struct Apply_LoadMultimem> { \ + static constexpr int PackSize = 1*(SIZEOF_BytePack_field_##pack_field); \ + __device__ static BytePack load(Fn fn, uintptr_t addr) { \ + BytePack ans; \ + asm("multimem.ld_reduce.global." #op "." #ptx_ty " %0, [%1];" \ + : "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field) \ + : "l"(addr)); \ + return ans; \ + } \ + }; +#define DEFINE_Apply_LoadMultimem_v4(Fn, T, op, ptx_ty, pack_field) \ + template<> \ + struct Apply_LoadMultimem> { \ + static constexpr int PackSize = 4*(SIZEOF_BytePack_field_##pack_field); \ + __device__ static BytePack load(Fn fn, uintptr_t addr) { \ + BytePack ans; \ + asm("multimem.ld_reduce.global." #op ".v4." #ptx_ty " {%0,%1,%2,%3}, [%4];" \ + : "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[0]), \ + "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[1]), \ + "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[2]), \ + "=" PTX_REG_BytePack_field_##pack_field(ans.pack_field[3]) \ + : "l"(addr)); \ + return ans; \ + } \ + }; + +#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010 + DEFINE_Apply_LoadMultimem(FuncSum, uint32_t, add, u32, u32) + DEFINE_Apply_LoadMultimem(FuncMin, uint32_t, min, u32, u32) + DEFINE_Apply_LoadMultimem(FuncMax, uint32_t, max, u32, u32) + + DEFINE_Apply_LoadMultimem(FuncSum, int32_t, add, s32, u32) + DEFINE_Apply_LoadMultimem(FuncMin, int32_t, min, s32, u32) + DEFINE_Apply_LoadMultimem(FuncMax, int32_t, max, s32, u32) + + DEFINE_Apply_LoadMultimem(FuncSum, uint64_t, add, u64, u64) + DEFINE_Apply_LoadMultimem(FuncMin, uint64_t, min, u64, u64) + DEFINE_Apply_LoadMultimem(FuncMax, uint64_t, max, u64, u64) + + DEFINE_Apply_LoadMultimem(FuncSum, int64_t, add, u64, u64) + DEFINE_Apply_LoadMultimem(FuncMin, int64_t, min, s64, u64) + DEFINE_Apply_LoadMultimem(FuncMax, int64_t, max, s64, u64) + + DEFINE_Apply_LoadMultimem_v4(FuncSum, float, add, f32, u32) + + DEFINE_Apply_LoadMultimem(FuncSum, double, add, f64, u64) + + DEFINE_Apply_LoadMultimem_v4(FuncSum, half, add, f16x2, u32) + DEFINE_Apply_LoadMultimem_v4(FuncMin, half, min, f16x2, u32) + DEFINE_Apply_LoadMultimem_v4(FuncMax, half, max, f16x2, u32) + + #if defined(__CUDA_BF16_TYPES_EXIST__) + DEFINE_Apply_LoadMultimem_v4(FuncSum, __nv_bfloat16, add, bf16x2, u32) + DEFINE_Apply_LoadMultimem_v4(FuncMin, __nv_bfloat16, min, bf16x2, u32) + DEFINE_Apply_LoadMultimem_v4(FuncMax, __nv_bfloat16, max, bf16x2, u32) + #endif +#endif + +#undef DEFINE_Apply_LoadMultimem +#undef DEFINE_Apply_LoadMultimem_v4 +#undef SIZEOF_BytePack_field_u64 +#undef PTX_REG_BytePack_field_u64 +#undef SIZEOF_BytePack_field_u32 +#undef PTX_REG_BytePack_field_u32 +#undef SIZEOF_BytePack_field_u16 +#undef PTX_REG_BytePack_field_u16 + #endif // REDUCE_KERNEL_H_ diff --git a/src/collectives/device/reduce_scatter.h b/src/collectives/device/reduce_scatter.h index 212444dc28..cf4278485c 100644 --- a/src/collectives/device/reduce_scatter.h +++ b/src/collectives/device/reduce_scatter.h @@ -92,3 +92,45 @@ struct RunWorkElement(args); } }; + +template +struct RunWorkElement { + __device__ __forceinline__ void run(ncclWorkElem *args) { + const int tid = threadIdx.x; + const int bid = args->bid; + const int nChannels = args->nChannels; + struct ncclNvls* nvls = &ncclShmem.channel.nvls; + const ssize_t chunkSize = int(args->lastChunkSize); + const ssize_t size = args->count; + const ssize_t loopSize = nChannels*chunkSize; + + const int nThreadsScatter = 128 + WARP_SIZE; + const int nThreadsReduce = 384; + const int tidEndScatter = nThreadsScatter; + const int tidEndReduce = tidEndScatter + nThreadsReduce; + + using Proto = ProtoSimple<1, 1>; + + if (tid < tidEndScatter) { + // Scatter + int group = (0*Proto::MaxGroupWidth) | (0<<16); + Primitives, /*Direct=*/0, Proto, 0> + prims(tid, nThreadsScatter, NULL, nvls->up, args->sendbuff, NULL, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.scatter(offset, nvls->nHeads*size, nelem, size, -1, 0); + } + } else if (tid < tidEndReduce) { + int group = (3*Proto::MaxGroupWidth) | (1<<16); + // Reduce through MC + Primitives, /*Direct=*/0, Proto, 0> + prims(tid-tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, args->recvbuff, args->redOpArg, group, args); + for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { + ssize_t offset = gridOffset + bid*chunkSize; + int nelem = min(chunkSize, size-offset); + prims.recv(offset, nelem); + } + } + } +}; diff --git a/src/collectives/device/sendrecv.h b/src/collectives/device/sendrecv.h index 408ad1b488..17bb25b39b 100644 --- a/src/collectives/device/sendrecv.h +++ b/src/collectives/device/sendrecv.h @@ -17,7 +17,7 @@ struct RunWork { template __device__ void runSend(const int tid, const int nthreads, const int group, struct ncclWorkElemP2p* args) { void* buff = reinterpret_cast(uintptr_t(args->buffHi32)<<32 | args->buffLo32); - size_t count = reinterpret_cast(size_t(args->countHi32)<<32 | args->countLo32); + ssize_t count = reinterpret_cast(size_t(args->countHi32)<<32 | args->countLo32); #if defined(ENABLE_NPKIT) bool isNpKitThread = (tid == 0); @@ -43,7 +43,8 @@ struct RunWork { } #endif - ReduceOrCopyMulti(tid, nthreads, nullptr, false, 1, (const T**)&buff, 1, (T**)&recvBuff, count); + ReduceOrCopyMulti + (tid, nthreads, 0, nullptr, false, 1, &buff, 1, &recvBuff, count); #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_PRIM_SIMPLE_REDUCE_OR_COPY_MULTI_EXIT) if (isNpKitThread) { @@ -59,6 +60,8 @@ struct RunWork { } #endif + ReduceOrCopyMulti + (tid, nthreads, 0, nullptr, false, 1, &buff, 1, &recvBuff, count); } } else { int chunkSize = args->chunkSize/sizeof(T); diff --git a/src/debug.cc b/src/debug.cc index 5955a6eb3f..560c1d26a0 100644 --- a/src/debug.cc +++ b/src/debug.cc @@ -74,6 +74,8 @@ void ncclDebugInit() { mask = NCCL_ALLOC; } else if (strcasecmp(subsys, "CALL") == 0) { mask = NCCL_CALL; + } else if (strcasecmp(subsys, "NVLS") == 0) { + mask = NCCL_NVLS; } else if (strcasecmp(subsys, "ALL") == 0) { mask = NCCL_ALL; } diff --git a/src/enqueue.cc b/src/enqueue.cc index d01817a0e9..591606fa33 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -29,58 +29,55 @@ struct ncclKernelMatch { typedef void(*ncclKern_t)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); // Must be consistent with the ncclFuncSet enum -static ncclKernelMatch const ncclKerns[4] = { +static ncclKernelMatch const ncclKerns[2] = { {(void *)NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), true}, {(void *)NCCL_KERN_NAME_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t), true}, - {(void *)NCCL_KERN_NAME_LL128(SendRecv, RING, SIMPLE, Sum, int8_t), true}, - {(void *)NCCL_KERN_NAME_LL128_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t), true}, }; static ncclResult_t computeColl(struct ncclInfo* info /* input */, int* workFuncIndex, struct ncclWorkElem* work, struct ncclProxyOp* proxyOp /* output */); -// Determine the maximum kernel stack size of all CUDA kernels -size_t ncclKernMaxLocalSize() { - ncclResult_t res = ncclSuccess; - int numNcclKerns = sizeof(ncclKerns)/sizeof(ncclKerns[0]); - cudaFuncAttributes attr = {0}; - size_t max = 0; - for (int i = 0; i < numNcclKerns; i++) { - if (ncclKerns[i].kernelFn != nullptr) { - CUDACHECKGOTO(cudaFuncGetAttributes(&attr, reinterpret_cast(ncclKerns[i].kernelFn)), res, error); - if (attr.localSizeBytes > max) max = attr.localSizeBytes; +NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0); + +// Returns maximum kernel stack size of all CUDA kernels +ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize) { + constexpr int KernelCount = sizeof(ncclKerns)/sizeof(ncclKerns[0]); + ncclResult_t result = ncclSuccess; + + if (maxStackSize) *maxStackSize = 0; + int carveout = ncclParamL1SharedMemoryCarveout(); + + // Keep track if we already visited a function pointer. + void* lru[2] = {nullptr, nullptr}; + for (int i=0; i < KernelCount; i++) { + void* fn = ncclKerns[i].kernelFn; + if (fn == lru[0] || fn == lru[1]) goto next_kernel; + lru[1] = lru[0]; + lru[0] = fn; + + if (maxStackSize) { + cudaFuncAttributes attr = {0}; + CUDACHECKGOTO(cudaFuncGetAttributes(&attr, fn), result, ignore0); + if (attr.localSizeBytes > *maxStackSize) *maxStackSize = attr.localSizeBytes; + ignore0:; } + + if (carveout) { + CUDACHECKGOTO(cudaFuncSetAttribute(fn, + cudaFuncAttributePreferredSharedMemoryCarveout, carveout), + result, ignore1); + ignore1:; + } + + if (ncclShmemDynamicSize(cudaArch) != 0) { + CUDACHECKGOTO(cudaFuncSetAttribute(fn, + cudaFuncAttributeMaxDynamicSharedMemorySize, ncclShmemDynamicSize(cudaArch)), + result, next_kernel); + } + next_kernel:; } - -error: - return (res != ncclSuccess) ? 0 : max; + return result; } -// Determine kernel stack size from index -size_t ncclKernLocalSize(int i) { - ncclResult_t res = ncclSuccess; - int numNcclKerns = sizeof(ncclKerns)/sizeof(ncclKerns[0]); - cudaFuncAttributes attr = {0}; - if (i < numNcclKerns) - CUDACHECKGOTO(cudaFuncGetAttributes(&attr, reinterpret_cast(ncclKerns[i].kernelFn)), res, error); - -error: - return (res != ncclSuccess) ? 0 : attr.localSizeBytes; -} - - -// Set shared memory carveout for the nccl kernels -ncclResult_t ncclKernSetSharedMemoryCarveout(int carveOut) { - ncclResult_t res = ncclSuccess; - int numNcclKerns = sizeof(ncclKerns)/sizeof(ncclKerns[0]); - for (int i = 0; i < numNcclKerns; i++) { - CUDACHECKGOTO(cudaFuncSetAttribute((const void *)ncclKerns[i].kernelFn, cudaFuncAttributePreferredSharedMemoryCarveout, carveOut), res, error); - } - -error: - return res; -} - - /*****************************************************************************/ /* Launch system : synchronization and CUDA kernel launch */ /*****************************************************************************/ @@ -211,10 +208,9 @@ static ncclResult_t addProxyOpIfNeeded(struct ncclComm* comm, struct ncclKernelP static ncclResult_t addCollToPlan( struct ncclComm* comm, struct ncclKernelPlan* plan, int* nWorkBudget, int funcIndex, struct ncclWorkElem const* workElem, struct ncclProxyOp const* proxyOp, - int nBid, size_t bytes, bool regBufUsed, void* regBufSend[], void* regBufRecv[] + int nCollChannels, int nBid, size_t bytes, bool regBufUsed, void* regBufSend[], void* regBufRecv[] ) { struct ncclKernelPlan::Channel *chans = plan->channels; - int nCollChannels = comm->nChannels; // Choose the `nBid` least loaded channels to do the work. This ensures // all bids go to different channels in case they need to synchronize. @@ -231,9 +227,7 @@ static ncclResult_t addCollToPlan( } } // Sort in the rest of the channels. If a channel has less work than the max - // member of least[], replace that member and compute the new max. The optimal - // algorithm uses a max-heap, but for our small sizes I suspect the better - // asymptotic complexity would be swamped by the increased instruction complexity. + // member of least[], replace that member and compute the new max. for (int c=nBid; c < nCollChannels; c++) { if (chans[c].collBytes < maxBytesInLeast) { least[maxIndexInLeast] = c; @@ -507,8 +501,9 @@ static ncclResult_t scheduleCollTasksToPlan( info.sliceSteps = head->sliceSteps; NCCLCHECK(ncclInfoSetDerived(&info, comm->nRanks)); if (nAggOps > 1) { + int maxChannels = aggInfo.algorithm == NCCL_ALGO_NVLS ? comm->nvlsChannels : comm->nChannels; info.nChannels = DIVUP(info.nBytes, bytePerChannel[collNetSupport]); - info.nChannels = std::max(1, std::min(info.nChannels, comm->nChannels)); + info.nChannels = std::max(1, std::min(info.nChannels, maxChannels)); info.algorithm = aggInfo.algorithm; info.protocol = aggInfo.protocol; info.nThreads = aggInfo.nThreads; @@ -531,8 +526,9 @@ static ncclResult_t scheduleCollTasksToPlan( NCCLCHECK(registerIntraNodeBuffers(comm, plan, &info, ®BufUsed, regBufSend, regBufRecv)); } + int maxChannels = info.algorithm == NCCL_ALGO_NVLS ? comm->nvlsChannels : comm->nChannels; NCCLCHECK(addCollToPlan(comm, plan, nWorkBudget, workFuncIndex, &workElem, &proxyOp, - info.nChannels, info.nBytes, regBufUsed, regBufSend, regBufRecv)); + maxChannels, info.nChannels, info.nBytes, regBufUsed, regBufSend, regBufRecv)); tasks->nTasksColl -= 1; tasks->collBytesTotal -= info.nBytes; ncclIntruQueueDequeue(&tasks->collQueue); @@ -830,7 +826,7 @@ static void HIPRT_CB hostStreamPlanCallback(void *plan_) { struct ncclKernelPlan* plan = (struct ncclKernelPlan*)plan_; ncclResult_t result = hostStreamPlanTask(plan->comm, plan); if (result != ncclSuccess) { - WARN("hostStreamPlanCallback() failed : %s\n", ncclGetErrorString(result)); + WARN("hostStreamPlanCallback() failed : %s", ncclGetErrorString(result)); } } @@ -943,7 +939,7 @@ ncclResult_t ncclLaunchPrepare(struct ncclComm* comm) { CUDACHECK(hipStreamWaitEvent(tasks->streams->stream, comm->doneEvent, 0)); } - if (persistent || comm->persistentRefs != 0) { + if (persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking) { // We have to launch host tasks to push proxy args. We are careful to only // do this if necessary since host tasks impose a high performance cost in CUDA. bool acquired = false; @@ -984,12 +980,6 @@ ncclResult_t ncclLaunchKernelBefore_NoUncapturedCuda(struct ncclComm* comm, stru return ncclSuccess; } -#if CUDART_VERSION >= 11080 -#define NCCL_MAX_CGA_CLUSTER_SIZE 8 -#define NCCL_CGA_CLUSTER_SIZE_SM90 4 -NCCL_PARAM(CGAClusterSize, "CGA_CLUSTER_SIZE", -2); -#endif - #if CUDART_VERSION >= 12000 // NCCL uses the "Remote" Mem Sync domain by default NCCL_PARAM(MemSyncDomain, "MEM_SYNC_DOMAIN", cudaLaunchMemSyncDomainRemote); @@ -1001,6 +991,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan cudaStream_t launchStream = tasks->streams->stream; dim3 grid = {(unsigned)plan->channelCount, 1, 1}; dim3 block = {(unsigned)plan->threadPerBlock, 1, 1}; + size_t smem = ncclShmemDynamicSize(comm->cudaArch); void *args[3] = {&comm->devComm, &plan->channelMask, &plan->workHead}; if (tasks->numStreams == 1) { CUDACHECK(hipExtLaunchKernel(plan->kernelFn, grid, block, args, 0, tasks->streams->stream, NULL, comm->doneEvent, 0)); @@ -1013,19 +1004,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan NCCLCHECK(ncclCudaDriverVersion(&driverVersion)); if (driverVersion >= 11080) { int compCap = comm->compCap; - unsigned int clusterSize = (compCap == 90) ? NCCL_CGA_CLUSTER_SIZE_SM90 : 0; - if (ncclParamCGAClusterSize() != -2) { - clusterSize = ncclParamCGAClusterSize(); - if (clusterSize > NCCL_MAX_CGA_CLUSTER_SIZE) { - static bool warned = false; - if (warned == false) { - WARN("NCCL_CGA_CLUSTER_SIZE value %d is too big. Limiting value to %d.", - clusterSize, NCCL_MAX_CGA_CLUSTER_SIZE); - warned = true; - } - clusterSize = NCCL_MAX_CGA_CLUSTER_SIZE; - } - } + unsigned int clusterSize = (compCap == 90) ? comm->cgaClusterSize : 0; cudaLaunchConfig_t launchConfig = {0}; cudaLaunchAttribute launchAttrs[3]; @@ -1057,6 +1036,7 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan #endif launchConfig.gridDim = grid; launchConfig.blockDim = block; + launchConfig.dynamicSmemBytes = smem; launchConfig.attrs = launchAttrs; launchConfig.numAttrs = attrs; launchConfig.stream = launchStream; @@ -1066,12 +1046,12 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan } #endif // Standard kernel launch - CUDACHECK(cudaLaunchKernel(fn, grid, block, args, 0, launchStream)); + CUDACHECK(cudaLaunchKernel(fn, grid, block, args, smem, launchStream)); return ncclSuccess; } ncclResult_t ncclLaunchKernelAfter_NoCuda(struct ncclComm* comm, struct ncclKernelPlan* plan) { - if (comm->persistentRefs == 0) { // implies !plan->persistent + if (!(plan->persistent || comm->persistentRefs != 0 || ncclCudaLaunchBlocking)) { // If this isn't being captured and there aren't any CUDA graphs alive // then we don't need to do our proxyOp pushing on the host stream. NCCLCHECK(hostStreamPlanTask(comm, plan)); @@ -1146,6 +1126,8 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i int nAlgos = NCCL_NUM_ALGORITHMS; for (int a=0; adatatype, info->opFull.op)) continue; + for (int p=0; palgorithm == NCCL_ALGO_NVLS) { + // NVLS should not need more than 16 channels to get peak BW. + nc = comm->nvlsChannels; } else { // Ring/Tree channel tuning while (info->nBytes < nc*nt*threadThreshold) { @@ -1198,6 +1183,7 @@ static ncclResult_t getAlgoInfo(struct ncclInfo* info, int collNetTypeSupport, i if (info->algorithm == NCCL_ALGO_TREE) nt += 3*WARP_SIZE; if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) nt += 3*WARP_SIZE; if (info->algorithm == NCCL_ALGO_COLLNET_CHAIN) nt += 3*WARP_SIZE; + if (info->algorithm == NCCL_ALGO_NVLS) nt = NCCL_MAX_NTHREADS; } nt = nt/WARP_SIZE < 3 ? 3*WARP_SIZE : nt; #endif @@ -1245,6 +1231,7 @@ static ncclResult_t getPatternInfo(struct ncclInfo* info) { info->pattern = ncclPatternRing; break; case ncclFuncAllReduce: info->pattern = + info->algorithm == NCCL_ALGO_NVLS ? ncclPatternNvls : info->algorithm == NCCL_ALGO_COLLNET_DIRECT ? ncclPatternCollnetDirect : info->algorithm == NCCL_ALGO_COLLNET_CHAIN ? ncclPatternCollnetChain : info->algorithm == NCCL_ALGO_TREE ? ncclPatternTreeUpDown : @@ -1264,6 +1251,7 @@ static ncclResult_t getLoopInfo(struct ncclInfo* info) { case ncclPatternPipelineFrom: case ncclPatternPipelineTo: case ncclPatternCollnetChain: + case ncclPatternNvls: info->nstepsPerLoop = info-> nchunksPerLoop = 1; break; case ncclPatternCollnetDirect: info->nstepsPerLoop = 1; info->nchunksPerLoop = info->comm->channels[0].collnetDirect.nHeads; break; @@ -1292,13 +1280,6 @@ comp_next: // Set nstepsPerLoop and nchunksPerLoop NCCLCHECK(getPatternInfo(info)); NCCLCHECK(getLoopInfo(info)); - if (info->comm->topo->pivotA2ANumBiRings == 3 ) { - if (ncclTypeSize(info->datatype)*info->count > 131072) { - work->pad_0 = 1; - } else { - work->pad_0 = 2; - } - } work->sendbuff = info->sendbuff; work->recvbuff = info->recvbuff; work->root = info->root; @@ -1359,6 +1340,14 @@ comp_next: while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth*8 && chunkSize > 65536) chunkSize /= 2; while (info->nBytes / (info->nChannels*chunkSize) < info->comm->channels[0].collnetChain.depth && chunkSize > 32768) chunkSize /= 2; work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); + } else if (info->algorithm == NCCL_ALGO_NVLS) { + if (chunkSize > 131072) chunkSize = 131072; + // Use uint64_t so that concurrentOps*chunkSize*X does not overflow + uint64_t concurrentOps = info->nChannels*info->comm->channels[0].nvls.nHeads; + if ((info->nBytes < (32 * (concurrentOps*chunkSize))) && (chunkSize > 65536)) chunkSize = 65536; + if ((info->nBytes < (8 * (concurrentOps*chunkSize))) && (chunkSize > 32768)) chunkSize = 32768; + if ((info->nBytes < (2 * (concurrentOps*chunkSize))) && (chunkSize > 16384)) chunkSize = 16384; + work->lastChunkSize = chunkSize / ncclTypeSize(info->datatype); } else if (info->protocol == NCCL_PROTO_LL) { const ssize_t sliceSize = stepSize*sizeof(uint64_t)/sizeof(union ncclLLFifoLine); const ssize_t loopSize = info->nChannels*info->nchunksPerLoop*(ssize_t)sliceSize; @@ -1671,6 +1660,11 @@ ncclResult_t ncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm) { WARN("ncclRedOpDestroy : operator is garbage."); return ncclInvalidArgument; } + if (comm == NULL) { + WARN("ncclRedOpDestroy : invalid communicator passed."); + return ncclInvalidArgument; + } + int ix = int(ncclUserRedOpMangle(comm, op)) - int(ncclNumOps); if (comm->userRedOpCapacity <= ix || comm->userRedOps[ix].freeNext != -1) { WARN("ncclRedOpDestroy : operator unknown to this communicator."); diff --git a/src/graph/connect.cc b/src/graph/connect.cc index 3ba3f8806d..617ea78e5f 100644 --- a/src/graph/connect.cc +++ b/src/graph/connect.cc @@ -5,25 +5,6 @@ * See LICENSE.txt for license information ************************************************************************/ -/* - * Code for binary tree based on the same function available in Open MPI - * File: ompi/mca/coll/base/coll_base_topo.c - * - * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana - * University Research and Technology - * Corporation. All rights reserved. - * Copyright (c) 2004-2015 The University of Tennessee and The University - * of Tennessee Research Foundation. All rights - * reserved. - * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart, - * University of Stuttgart. All rights reserved. - * Copyright (c) 2004-2005 The Regents of the University of California. - * All rights reserved. - * Copyright (c) 2015 Research Organization for Information Science - * and Technology (RIST). All rights reserved. - */ - - #include "comm.h" #include "graph.h" #include "trees.h" @@ -95,279 +76,6 @@ ncclResult_t ncclTopoPreset(struct ncclComm* comm, return ncclSuccess; } -static int calculate_level (int rank) -{ - int level, num; - if( rank < 0 ) return -1; - for( level = 0, num = 0; num <= rank; level++ ) { - num += 1<nChannels; - int localRanks = 0; - for (int i=0; itopo->nodes[GPU].count; i++) { - localRanks += comm->topo->nodes[GPU].nodes[i].gpu.nRanksPerGpu; - } - - for (int c=0; cchannels+c; - // Only the first rank on a GPU can be a treeRoot - int treeRoot = comm->topo->nodes[GPU].nodes[c%comm->topo->nodes[GPU].count].gpu.rank[0]; - - channel->binTree.up = -1; - channel->binTree.down[0] = -1; - channel->binTree.down[1] = -1; - channel->binTree.down[2] = -1; - - /* - * Shift all ranks by root, so that the algorithm can be - * designed as if root would be always 0 - * shiftedrank should be used in calculating distances - * and position in tree - */ - int shiftedrank = comm->rank - treeRoot; - if (shiftedrank < 0 ) { - shiftedrank += localRanks; - } - - /* calculate my level */ - int level = calculate_level (shiftedrank); - int delta = 1<binTree.down[i] = (schild+treeRoot)%localRanks; - } - } - - /* find my parent */ - int slimit = calculate_num_nodes_up_to_level (level); - int sparent = shiftedrank; - if (sparent < 2) { - sparent = 0; - } - else { - while (sparent >= slimit) { - sparent -= delta/2; - } - } - if (comm->rank != treeRoot) { - channel->binTree.up = (sparent+treeRoot)%localRanks; - } - } - - return ncclSuccess; -} - -#define NUM_HAYABUSA_TREES 2 -static bool hayabusa_tree_matrix_is_init=false; -static int hayabusa_tree_matrix[NUM_HAYABUSA_TREES][16][4]; - -static void hayabusa_tree_matrix_init() -{ - if (hayabusa_tree_matrix_is_init) - return; - - // index = rank of proc, child0, child1, child2, parent - // channel 0: root is 15 - hayabusa_tree_matrix[0][0][0] = 1; - hayabusa_tree_matrix[0][0][1] = -1; - hayabusa_tree_matrix[0][0][2] = -1; - hayabusa_tree_matrix[0][0][3] = 4; - - hayabusa_tree_matrix[0][1][0] = -1; - hayabusa_tree_matrix[0][1][1] = -1; - hayabusa_tree_matrix[0][1][2] = -1; - hayabusa_tree_matrix[0][1][3] = 0; - - hayabusa_tree_matrix[0][2][0] = 3; - hayabusa_tree_matrix[0][2][1] = -1; - hayabusa_tree_matrix[0][2][2] = -1; - hayabusa_tree_matrix[0][2][3] = 6; - - hayabusa_tree_matrix[0][3][0] = -1; - hayabusa_tree_matrix[0][3][1] = -1; - hayabusa_tree_matrix[0][3][2] = -1; - hayabusa_tree_matrix[0][3][3] = 2; - - hayabusa_tree_matrix[0][4][0] = 0; - hayabusa_tree_matrix[0][4][1] = -1; - hayabusa_tree_matrix[0][4][2] = -1; - hayabusa_tree_matrix[0][4][3] = 5; - - hayabusa_tree_matrix[0][5][0] = 4; - hayabusa_tree_matrix[0][5][1] = -1; - hayabusa_tree_matrix[0][5][2] = -1; - hayabusa_tree_matrix[0][5][3] = 14; - - hayabusa_tree_matrix[0][6][0] = 2; - hayabusa_tree_matrix[0][6][1] = 7; - hayabusa_tree_matrix[0][6][2] = -1; - hayabusa_tree_matrix[0][6][3] = 14; - - hayabusa_tree_matrix[0][7][0] = -1; - hayabusa_tree_matrix[0][7][1] = -1; - hayabusa_tree_matrix[0][7][2] = -1; - hayabusa_tree_matrix[0][7][3] = 6; - - hayabusa_tree_matrix[0][8][0] = -1; - hayabusa_tree_matrix[0][8][1] = -1; - hayabusa_tree_matrix[0][8][2] = -1; - hayabusa_tree_matrix[0][8][3] = 9; - - hayabusa_tree_matrix[0][9][0] = 13; - hayabusa_tree_matrix[0][9][1] = 8; - hayabusa_tree_matrix[0][9][2] = -1; - hayabusa_tree_matrix[0][9][3] = 11; - - hayabusa_tree_matrix[0][10][0] = -1; - hayabusa_tree_matrix[0][10][1] = -1; - hayabusa_tree_matrix[0][10][2] = -1; - hayabusa_tree_matrix[0][10][3] = 11; - - hayabusa_tree_matrix[0][11][0] = 9; - hayabusa_tree_matrix[0][11][1] = 10; - hayabusa_tree_matrix[0][11][2] = -1; - hayabusa_tree_matrix[0][11][3] = 15; - - hayabusa_tree_matrix[0][12][0] = -1; - hayabusa_tree_matrix[0][12][1] = -1; - hayabusa_tree_matrix[0][12][2] = -1; - hayabusa_tree_matrix[0][12][3] = 13; - - hayabusa_tree_matrix[0][13][0] = 12; - hayabusa_tree_matrix[0][13][1] = -1; - hayabusa_tree_matrix[0][13][2] = -1; - hayabusa_tree_matrix[0][13][3] = 9; - - hayabusa_tree_matrix[0][14][0] = 5; - hayabusa_tree_matrix[0][14][1] = 6; - hayabusa_tree_matrix[0][14][2] = -1; - hayabusa_tree_matrix[0][14][3] = 15; - - hayabusa_tree_matrix[0][15][0] = 14; - hayabusa_tree_matrix[0][15][1] = 11; - hayabusa_tree_matrix[0][15][2] = -1; - hayabusa_tree_matrix[0][15][3] = -1; - - //Channel 1: root is 6 - hayabusa_tree_matrix[1][0][0] = -1; - hayabusa_tree_matrix[1][0][1] = -1; - hayabusa_tree_matrix[1][0][2] = -1; - hayabusa_tree_matrix[1][0][3] = 1; - - hayabusa_tree_matrix[1][1][0] = 5; - hayabusa_tree_matrix[1][1][1] = 0; - hayabusa_tree_matrix[1][1][2] = -1; - hayabusa_tree_matrix[1][1][3] = 3; - - hayabusa_tree_matrix[1][2][0] = -1; - hayabusa_tree_matrix[1][2][1] = -1; - hayabusa_tree_matrix[1][2][2] = -1; - hayabusa_tree_matrix[1][2][3] = 3; - - hayabusa_tree_matrix[1][3][0] = 1; - hayabusa_tree_matrix[1][3][1] = 2; - hayabusa_tree_matrix[1][3][2] = -1; - hayabusa_tree_matrix[1][3][3] = 7; - - hayabusa_tree_matrix[1][4][0] = -1; - hayabusa_tree_matrix[1][4][1] = -1; - hayabusa_tree_matrix[1][4][2] = -1; - hayabusa_tree_matrix[1][4][3] = 5; - - hayabusa_tree_matrix[1][5][0] = 4; - hayabusa_tree_matrix[1][5][1] = -1; - hayabusa_tree_matrix[1][5][2] = -1; - hayabusa_tree_matrix[1][5][3] = 1; - - hayabusa_tree_matrix[1][6][0] = 7; - hayabusa_tree_matrix[1][6][1] = 13; - hayabusa_tree_matrix[1][6][2] = -1; - hayabusa_tree_matrix[1][6][3] = -1; - - hayabusa_tree_matrix[1][7][0] = 3; - hayabusa_tree_matrix[1][7][1] = 15; - hayabusa_tree_matrix[1][7][2] = -1; - hayabusa_tree_matrix[1][7][3] = 6; - - hayabusa_tree_matrix[1][8][0] = 9; - hayabusa_tree_matrix[1][8][1] = -1; - hayabusa_tree_matrix[1][8][2] = -1; - hayabusa_tree_matrix[1][8][3] = 12; - - hayabusa_tree_matrix[1][9][0] = -1; - hayabusa_tree_matrix[1][9][1] = -1; - hayabusa_tree_matrix[1][9][2] = -1; - hayabusa_tree_matrix[1][9][3] = 8; - - hayabusa_tree_matrix[1][10][0] = -1; - hayabusa_tree_matrix[1][10][1] = -1; - hayabusa_tree_matrix[1][10][2] = -1; - hayabusa_tree_matrix[1][10][3] = 11; - - hayabusa_tree_matrix[1][11][0] = 10; - hayabusa_tree_matrix[1][11][1] = -1; - hayabusa_tree_matrix[1][11][2] = -1; - hayabusa_tree_matrix[1][11][3] = 15; - - hayabusa_tree_matrix[1][12][0] = 8; - hayabusa_tree_matrix[1][12][1] = -1; - hayabusa_tree_matrix[1][12][2] = -1; - hayabusa_tree_matrix[1][12][3] = 13; - - hayabusa_tree_matrix[1][13][0] = 12; - hayabusa_tree_matrix[1][13][1] = -1; - hayabusa_tree_matrix[1][13][2] = -1; - hayabusa_tree_matrix[1][13][3] = 6; - - hayabusa_tree_matrix[1][14][0] = -1; - hayabusa_tree_matrix[1][14][1] = -1; - hayabusa_tree_matrix[1][14][2] = -1; - hayabusa_tree_matrix[1][14][3] = 15; - - hayabusa_tree_matrix[1][15][0] = 11; - hayabusa_tree_matrix[1][15][1] = 14; - hayabusa_tree_matrix[1][15][2] = -1; - hayabusa_tree_matrix[1][15][3] = 7; - - hayabusa_tree_matrix_is_init = true; -} - -static void set_channel_info(int c, int rank, struct ncclChannel *channel) -{ - channel->binTree.down[0] = hayabusa_tree_matrix[c%NUM_HAYABUSA_TREES][rank][0]; - channel->binTree.down[1] = hayabusa_tree_matrix[c%NUM_HAYABUSA_TREES][rank][1]; - channel->binTree.down[2] = hayabusa_tree_matrix[c%NUM_HAYABUSA_TREES][rank][2]; - channel->binTree.up = hayabusa_tree_matrix[c%NUM_HAYABUSA_TREES][rank][3]; -} - -ncclResult_t ncclBinaryTreeHayabusaPostset(struct ncclComm* comm, - struct ncclTopoGraph* treeGraph) { - int nChannels = comm->nChannels; - - hayabusa_tree_matrix_init(); - - for (int c=0; cchannels+c; - - set_channel_info(c, comm->localRank, channel); - } - - return ncclSuccess; -} - ncclResult_t ncclTreeBasePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph) { int nChannels = comm->nChannels; diff --git a/src/graph/paths.cc b/src/graph/paths.cc index 3d7e052a2d..cbd529bda0 100644 --- a/src/graph/paths.cc +++ b/src/graph/paths.cc @@ -486,7 +486,7 @@ ncclResult_t ncclTopoGetIntermediateRank(struct ncclTopoSystem* system, int rank type = node->type; } if (type != GPU) { - WARN("Could not find intermediate GPU between GPU rank %d and NIC %d\n", rank, netDev); + WARN("Could not find intermediate GPU between GPU rank %d and NIC %d", rank, netDev); return ncclInternalError; } *intermediateRank = node->gpu.rank[0]; @@ -802,6 +802,7 @@ static int nextPow2(int v) { } ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm) { + /* here we already honor comm->max/minCTAs for p2pnChannels. */ comm->p2pnChannels = std::min(comm->nChannels, (int)ncclParamMaxP2pNChannels()); comm->p2pnChannels = std::max(comm->p2pnChannels, (int)ncclParamMinP2pNChannels()); int minChannels = comm->p2pnChannels; @@ -842,7 +843,6 @@ ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm) { for (int b=1, mb=(comm->p2pnChannels>>1); bp2pnChannels; b<<=1, mb>>=1) if (c & b) mirror |= mb; comm->p2pChannels[c] = mirror; } - INFO(NCCL_INIT, "%d coll channels, %d p2p channels, %d p2p channels per peer", comm->nChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer); return ncclSuccess; } diff --git a/src/graph/search.cc b/src/graph/search.cc index 5a431c17ed..a8f840ce80 100644 --- a/src/graph/search.cc +++ b/src/graph/search.cc @@ -896,7 +896,6 @@ ncclResult_t ncclTopoCompute(ncclTopoSystem* system, struct ncclTopoGraph* graph } if (ngpus == 1) if (graph->pattern != NCCL_TOPO_PATTERN_RING) graph->pattern = NCCL_TOPO_PATTERN_TREE; - // SPLIT_TREE works better on older archs. int ccMin; NCCLCHECK(ncclTopoGetCompCap(system, &ccMin, NULL)); @@ -1177,6 +1176,7 @@ ncclResult_t ncclTopoGetIntraNetDev(struct ncclTopoSystem* system, int rank, str ncclResult_t ncclTopoGetLinkType(struct ncclTopoSystem* system, int cudaDev1, int cudaDev2, bool* isXGMI, int maxInter, int nInter, int *inter) { int interGpus[MAX_XGMI_INTER_GPUS+1]; int ngpus = system->nodes[GPU].count; + *isXGMI = false; // check for direct XGMI connection for (int i=0; inodes[GPU].nodes[i].gpu.dev == cudaDev1) { @@ -1231,6 +1231,5 @@ ncclResult_t ncclTopoGetLinkType(struct ncclTopoSystem* system, int cudaDev1, in } } } - *isXGMI = false; return ncclSuccess; } diff --git a/src/graph/topo.cc b/src/graph/topo.cc index 16a91b8c6f..d41293f7dc 100644 --- a/src/graph/topo.cc +++ b/src/graph/topo.cc @@ -917,6 +917,6 @@ ncclResult_t ncclTopoGetLocalRank(struct ncclTopoSystem* system, int rank, int* } } } - WARN("Could not find local GPU with rank %d\n", rank); + WARN("Could not find local GPU with rank %d", rank); return ncclInternalError; } diff --git a/src/graph/topo.h b/src/graph/topo.h index 5425e98e54..730e8faa52 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -225,6 +225,6 @@ static float ncclTopoXGMISpeed(int gcn) { } #define ncclGetKernelIndex(p_comm) \ - (((p_comm)->topo->ll128Enabled ? 1 : 0)*2 + ((p_comm)->collTraceThread ? 1 : 0)) + ((p_comm)->collTraceThread ? 1 : 0) #endif diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc index 2112279bb5..1a6d04c772 100644 --- a/src/graph/tuning.cc +++ b/src/graph/tuning.cc @@ -54,7 +54,7 @@ ncclResult_t parseList(const char* str, const char* elems[], int nelems, int* li // Latencies in us, Bandwidths in GB/s // Tree { LL, LL128, Simple } , Ring { LL, LL128, Simple } -static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 12.0, 12.0, 17.0 }, { 12.0, 12.0, 17.0 }, { 12.0, 12.0, 17.0 } }; +static const float baseLat [NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = { { 12.0, 12.0, 17.0 }, { 12.0, 12.0, 17.0 }, { 12.0, 12.0, 17.0 }, { 12.0, 12.0, 17.0 } }; // NVLink, PCI, Network #define NCCL_HW_NVLINK 0 @@ -71,18 +71,18 @@ struct tuningModel { static struct tuningModel tuning_model_0 { .hwLat = { /* NVLINK */ - { /* Tree (LL/LL128/Simple)*/ { 0.8, 1.4, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 2.2, 3.6 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 0.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 1.4 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.8, 1.4, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 2.2, 3.6 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 0.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 1.4 }, /* NVLS */ { 0, 0, 0 } }, /* PCI */ - { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 } }, + { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 }, /* NVLS */ { 0, 0, 0 } }, /* NET */ - { /* Tree (LL/LL128/Simple)*/ { 11.8, 18.2, 20.8 }, /* Ring (LL/LL128/Simple)*/ { 9.5, 19.8, 15.1 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 11.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 18.2 } }, + { /* Tree (LL/LL128/Simple)*/ { 11.8, 18.2, 20.8 }, /* Ring (LL/LL128/Simple)*/ { 9.5, 19.8, 15.1 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 11.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 18.2 }, /* NVLS */ { 0, 0, 0 } }, }, .bwRatio = { /* 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.04, 0.22, 0.91 }, /* Ring (LL/LL128/Simple)*/ { 0.04, 0.34, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.04, 0.22, 0.91 }, /* Ring (LL/LL128/Simple)*/ { 0.04, 0.34, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, /* more than 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.04, 0.22, 0.95 }, /* Ring (LL/LL128/Simple)*/ { 0.04, 0.34, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.04, 0.22, 0.95 }, /* Ring (LL/LL128/Simple)*/ { 0.04, 0.34, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, }, .treeCorrectionFactor = { @@ -101,18 +101,18 @@ static struct tuningModel tuning_model_0 { static struct tuningModel tuning_model_1 { .hwLat = { /* NVLINK */ - { /* Tree (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* Ring (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 4.5 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 4.5 } }, + { /* Tree (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* Ring (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 4.5 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 4.5 }, /* NVLS */ { 0, 0, 0 } }, /* PCI */ - { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 } }, + { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 }, /* NVLS */ { 0, 0, 0 } }, /* NET */ - { /* Tree (LL/LL128/Simple)*/ { 33.0, 33.0, 15.8 }, /* Ring (LL/LL128/Simple)*/ { 5.1, 5.1, 68.8 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 15.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 15.8 } }, + { /* Tree (LL/LL128/Simple)*/ { 33.0, 33.0, 15.8 }, /* Ring (LL/LL128/Simple)*/ { 5.1, 5.1, 68.8 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 15.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 15.8 }, /* NVLS */ { 0, 0, 0 } }, }, .bwRatio = { /* 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.12, 1.00, 0.99 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 1.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.12, 1.00, 0.99 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 1.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, /* more than 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.15, 1.00, 0.42 }, /* Ring (LL/LL128/Simple)*/ { 0.20, 1.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.15, 1.00, 0.42 }, /* Ring (LL/LL128/Simple)*/ { 0.20, 1.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, }, .treeCorrectionFactor = { @@ -131,18 +131,18 @@ static struct tuningModel tuning_model_1 { static struct tuningModel tuning_model_2 { .hwLat = { /* NVLINK */ - { /* Tree (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* Ring (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 4.5 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 4.5 } }, + { /* Tree (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* Ring (LL/LL128/Simple)*/ { 1.5, 1.5, 4.5 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 4.5 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 4.5 }, /* NVLS */ { 0, 0, 0 } }, /* PCI */ - { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 } }, + { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 }, /* NVLS */ { 0, 0, 0 } }, /* NET */ - { /* Tree (LL/LL128/Simple)*/ { 27.9, 27.9, 15.8 }, /* Ring (LL/LL128/Simple)*/ { 12.1, 12.1, 68.8 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 15.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 15.8 } }, + { /* Tree (LL/LL128/Simple)*/ { 27.9, 27.9, 15.8 }, /* Ring (LL/LL128/Simple)*/ { 12.1, 12.1, 68.8 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 15.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 15.8 }, /* NVLS */ { 0, 0, 0 } }, }, .bwRatio = { /* 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.07, 1.00, 0.99 }, /* Ring (LL/LL128/Simple)*/ { 0.08, 1.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.07, 1.00, 0.99 }, /* Ring (LL/LL128/Simple)*/ { 0.08, 1.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, /* more than 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.07, 1.00, 0.42 }, /* Ring (LL/LL128/Simple)*/ { 0.08, 1.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.07, 1.00, 0.42 }, /* Ring (LL/LL128/Simple)*/ { 0.08, 1.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, }, .treeCorrectionFactor = { @@ -161,18 +161,18 @@ static struct tuningModel tuning_model_2 { static struct tuningModel tuning_model_3 { .hwLat = { /* NVLINK */ - { /* Tree (LL/LL128/Simple)*/ { 0.8, 0.0, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 0.0, 3.6 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 0.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 0.0 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.8, 0.0, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 0.0, 3.6 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 0.8 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 0.0 }, /* NVLS */ { 0, 0, 0 } }, /* PCI */ - { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 } }, + { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 }, /* NVLS */ { 0, 0, 0 } }, /* NET */ - { /* Tree (LL/LL128/Simple)*/ { 12.5, 0.0, 22.4 }, /* Ring (LL/LL128/Simple)*/ { 9.5, 0.0, 19.8 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 12.5 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 0.0 } }, + { /* Tree (LL/LL128/Simple)*/ { 12.5, 0.0, 22.4 }, /* Ring (LL/LL128/Simple)*/ { 9.5, 0.0, 19.8 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 12.5 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 0.0 }, /* NVLS */ { 0, 0, 0 } }, }, .bwRatio = { /* 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.20, 0.00, 1.75 }, /* Ring (LL/LL128/Simple)*/ { 0.20, 0.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.20, 0.00, 1.75 }, /* Ring (LL/LL128/Simple)*/ { 0.20, 0.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, /* more than 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.20, 0.00, 0.96 }, /* Ring (LL/LL128/Simple)*/ { 0.20, 0.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.20, 0.00, 0.96 }, /* Ring (LL/LL128/Simple)*/ { 0.20, 0.00, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, }, .treeCorrectionFactor = { @@ -193,16 +193,16 @@ static struct tuningModel tuning_model_4 { /* NVLINK */ { /* Tree (LL/LL128/Simple)*/ { 0.8, 1.4, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 2.2, 3.6 }, /* CollNetDirect (Simple)*/ { 0.8, 1.4, 2.5 }, /* CollNetChain (Simple)*/ { 0.8, 1.4, 2.5 } }, /* PCI */ - { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 } }, + { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 5.7 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 5.7 }, /* NVLS */ { 0, 0, 0 } }, /* NET */ { /* Tree (LL/LL128/Simple)*/ { 32.2, 34.4, 47.6 }, /* Ring (LL/LL128/Simple)*/ { 35.4, 87.8, 209.2 }, /* CollNetDirect (Simple)*/ { 0.0, 0.0, 47.6 }, /* CollNetChain (Simple)*/ { 0.0, 0.0, 47.6 } }, }, .bwRatio = { /* 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.16, 1.09, 1.61 }, /* Ring (LL/LL128/Simple)*/ { 0.15, 0.41, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.16, 1.09, 1.61 }, /* Ring (LL/LL128/Simple)*/ { 0.15, 0.41, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, /* more than 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.16, 1.09, 1.08 }, /* Ring (LL/LL128/Simple)*/ { 0.15, 0.41, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.16, 1.09, 1.08 }, /* Ring (LL/LL128/Simple)*/ { 0.15, 0.41, 1.00 }, /* CollNetDirect (Simple)*/ { 0.00, 0.00, 1.00 }, /* CollNetChain (Simple)*/ { 0.00, 0.00, 1.00 }, /* NVLS */ { 0, 0, 0 } }, }, .treeCorrectionFactor = { @@ -232,7 +232,7 @@ static struct tuningModel rcclTuningModel[] = { #define HOPPER_COMPCAP_IDX 2 // LL128 max BW per channel -static const double ll128MaxBwPerCh = 20.0; +static const double ll128MaxBwPerCh[3] = { 20.0, 20.0, 36.7 }; static const double llMaxBws[3][3] = { /* Volta-N1/Intel-N2/Intel-N4) */ {39.0, 39.0, 20.4}, /* Ampere-N1/AMD-N2/AMD-N4) */ {87.7, 22.5 /*avg of ring & tree*/, 19.0}, @@ -242,7 +242,7 @@ static const double llMaxBws[3][3] = { static const double perChMaxTreeBws[3][3] = { /* Volta (N1/N2/N4) */ {26.5, 18.5, 10.0}, /* Ampere (N1/N2/N4) */ {24.0, 23.6, 17.8}, - /* Hopper (N1/N2/N4) */ {24.0, 23.6, 17.8}, + /* Hopper (N1/N2/N4) */ {38.7, 41.4, 33.0}, }; ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph) { @@ -261,7 +261,8 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_SIMPLE_MAX_NTHREADS, NCCL_SIMPLE_MAX_NTHREADS); comm->maxThreads[NCCL_ALGO_COLLNET_DIRECT][NCCL_PROTO_SIMPLE] = - comm->maxThreads[NCCL_ALGO_COLLNET_CHAIN][NCCL_PROTO_SIMPLE] = NCCL_SIMPLE_MAX_NTHREADS; + comm->maxThreads[NCCL_ALGO_COLLNET_CHAIN][NCCL_PROTO_SIMPLE] = + comm->maxThreads[NCCL_ALGO_NVLS][NCCL_PROTO_SIMPLE] = NCCL_SIMPLE_MAX_NTHREADS; comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL] = getNthreads("NCCL_NTHREADS", ncclParamNthreads(), 2*WARP_SIZE, NCCL_LL_MAX_NTHREADS, NCCL_LL_MAX_NTHREADS); comm->maxThreads[NCCL_ALGO_RING][NCCL_PROTO_LL128] = comm->maxThreads[NCCL_ALGO_TREE][NCCL_PROTO_LL128] = @@ -272,7 +273,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom int nRanks = comm->nRanks; if (nRanks <= 1) return ncclSuccess; - int compCapIndex = (minCompCap == 80 && maxCompCap == 80) ? AMPERE_COMPCAP_IDX : ((minCompCap == 90 && maxCompCap == 90) ? HOPPER_COMPCAP_IDX : VOLTA_COMPCAP_IDX); + int compCapIndex = minCompCap >= 90 ? HOPPER_COMPCAP_IDX : minCompCap >= 80 ? AMPERE_COMPCAP_IDX : VOLTA_COMPCAP_IDX; int cpuArch, cpuVendor, cpuModel; NCCLCHECK(ncclTopoCpuType(comm->topo, &cpuArch, &cpuVendor, &cpuModel)); int index2 = nNodes <= 2 ? nNodes-1 : 2; @@ -284,7 +285,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom //if (cpuArch == NCCL_TOPO_CPU_ARCH_POWER) hwLat[NCCL_HW_PCI][NCCL_ALGO_TREE][NCCL_PROTO_SIMPLE] = hwLat[NCCL_HW_PCI][NCCL_ALGO_RING][NCCL_PROTO_SIMPLE]; float ppn = (float)nRanks / nNodes; // if ppn < 2, then we are sending/receiving at the same GPU through the NIC, apply some bw discount - struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph, collNetGraph }; + struct ncclTopoGraph* graphs[NCCL_NUM_ALGORITHMS] = { treeGraph, ringGraph, collNetGraph, collNetGraph, ringGraph/* we only need the NVSwitch speed for NVLS*/ }; int intraHw[NCCL_NUM_ALGORITHMS], hw[NCCL_NUM_ALGORITHMS]; for (int a=0; atypeIntra == LINK_NVL ? NCCL_HW_NVLINK : NCCL_HW_PCI; for (int a=0; abwIntra : graphs[a]->bwInter; float busBw = comm->topo->baseBw != 0.0 ? comm->topo->baseBw : graphs[a]->nChannels * bw; @@ -314,11 +316,12 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom busBw *= rcclTuningModel[comm->topo->tuning].bwRatio[1][a][p]; #else if (compCapIndex == AMPERE_COMPCAP_IDX) busBw = std::min(busBw, 235.0f); + if (compCapIndex == HOPPER_COMPCAP_IDX) busBw = std::min(busBw, 370.0f); if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL) { busBw = std::min(llMaxBw, busBw * ((nNodes > 1 || coll == ncclFuncAllReduce || coll == ncclFuncReduce) ? 1.0/4.0 : 1.0/3.0)); } - if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh*graphs[a]->nChannels); + if (a == NCCL_ALGO_RING && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (ppn < 2 ? 0.7 : 0.92 /*120.0/128.0*/), ll128MaxBwPerCh[compCapIndex]*graphs[a]->nChannels); if (a == NCCL_ALGO_TREE) busBw = std::min(busBw*.92, graphs[a]->nChannels*perChMaxTreeBw); if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL) busBw = std::min(busBw*1.0/3.8, llMaxBw); - if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), ll128MaxBwPerCh*graphs[a]->nChannels); + if (a == NCCL_ALGO_TREE && p == NCCL_PROTO_LL128) busBw = std::min(busBw * (nNodes == 1 ? 7.0/9.0 : 120.0/128.0), ll128MaxBwPerCh[compCapIndex]*graphs[a]->nChannels); if (a == NCCL_ALGO_COLLNET_DIRECT && p != NCCL_PROTO_SIMPLE) busBw = 0; // Not used if (a == NCCL_ALGO_COLLNET_CHAIN && p != NCCL_PROTO_SIMPLE) busBw = 0; // Not used if (a == NCCL_ALGO_COLLNET_DIRECT && p == NCCL_PROTO_SIMPLE) { @@ -331,7 +334,10 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom if (a == NCCL_ALGO_COLLNET_CHAIN && p == NCCL_PROTO_SIMPLE) busBw *= .75; // Convert bus BW to algorithm BW - float ratio = (a != NCCL_ALGO_RING) ? .5 : (1.0 * nRanks) / nsteps; + float ratio; + if (a == NCCL_ALGO_RING) ratio = (1.0 * nRanks) / nsteps; + else if (a == NCCL_ALGO_NVLS) ratio = .75; + else ratio = .5; comm->bandwidths[coll][a][p] = busBw * ratio; comm->latencies[coll][a][p] = baseLat[a][p]; @@ -366,7 +372,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom // Protocols/Algorithms enable/disable, and user overrides. // All are enabled except ll128 which is enabled by default only in certain cases. int protoEnable[NCCL_NUM_PROTOCOLS] = { 1, 2, 1 }; - int algoEnable[NCCL_NUM_ALGORITHMS] = { 1, 1, 1, 1 }; + int algoEnable[NCCL_NUM_ALGORITHMS] = { 1, 1, 1, 1, 1 }; const char *protoStr = getenv("NCCL_PROTO"); if (protoStr) { @@ -378,6 +384,10 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom INFO(NCCL_ENV, "NCCL_ALGO set by environment to %s", algoStr); NCCLCHECK(parseList(algoStr, ncclAlgoStr, NCCL_NUM_ALGORITHMS, algoEnable)); } + + // Disable NVLink SHARP if not supported + if (comm->nvlsSupport == 0 /* || comm->localRanks <= 2*/) algoEnable[NCCL_ALGO_NVLS] = 0; + // Disable CollNet if it is not supported if (comm->collNetSupport == 0) { algoEnable[NCCL_ALGO_COLLNET_DIRECT] = 0; @@ -404,7 +414,7 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom #else // Enable LL128 by default only on Volta/Ampere/Hopper+NVLink. Other cases are not tested and may cause silent data corruption. pEnable = 1; - pEnable &= (graphs[a]->typeInter <= PATH_PXB); + pEnable &= (graphs[a]->typeInter <= PATH_PXB || (minCompCap >= 90 && graphs[a]->typeInter <= PATH_PXN)); pEnable &= (graphs[a]->typeIntra <= PATH_NVL); pEnable &= (minCompCap == maxCompCap); switch (minCompCap) { @@ -416,8 +426,9 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom #endif } if (pEnable == 0) comm->bandwidths[c][a][p] = 0; - // Only disable algo for Allreduce since others only have one - if (c == ncclFuncAllReduce && algoEnable[a] == 0) comm->bandwidths[c][a][p] = 0; + // Never disable ring for non-allreduce operations. That allows to run real apps with NCCL_ALGO=TREE. + if (a == NCCL_ALGO_RING && c != ncclFuncAllReduce) continue; + if (algoEnable[a] == 0) comm->bandwidths[c][a][p] = 0; } if (comm->rank == 0) { @@ -461,9 +472,9 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom char* str = getenv("NCCL_THREAD_THRESHOLDS"); if (str) { INFO(NCCL_ENV, "NCCL_THREAD_THRESHOLDS set by environment to %s", str); - ssize_t t[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS] = {{ -2, -2, -2 }, { -2, -2, -2 }, { -2, -2, -2 }, { -2, -2, -2 }}; + ssize_t t[2][NCCL_NUM_PROTOCOLS] = {{ -2, -2, -2 }, { -2, -2, -2 }}; sscanf(str, "%ld %ld %ld %ld %ld %ld", t[0], t[0]+1, t[0]+2, t[1], t[1]+1, t[1]+2); - for (int a=0; a= 0) comm->threadThresholds[a][p] = t[a][p]; } diff --git a/src/group.cc b/src/group.cc index e56354d686..477b34ed32 100644 --- a/src/group.cc +++ b/src/group.cc @@ -328,7 +328,7 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) { ret = ncclSystemError; } job->state = ncclGroupJobJoined; - if (job->result != ncclSuccess) { + if (job->result != ncclSuccess && ret == ncclSuccess) { ret = job->result; errorJobAbortFlag = true; } @@ -339,7 +339,6 @@ static ncclResult_t groupLaunch(struct ncclAsyncJob *job_) { if (*groupAbortFlag == true || errorJobAbortFlag == true) { *job->abortFlag = 1; - ret = ncclInternalError; } job = job->next; diff --git a/src/include/bootstrap.h b/src/include/bootstrap.h index e70db041df..2ecea7a94f 100644 --- a/src/include/bootstrap.h +++ b/src/include/bootstrap.h @@ -25,6 +25,7 @@ ncclResult_t bootstrapSend(void* commState, int peer, int tag, void* data, int s ncclResult_t bootstrapRecv(void* commState, int peer, int tag, void* data, int size); ncclResult_t bootstrapBarrier(void* commState, int *ranks, int rank, int nranks, int tag); ncclResult_t bootstrapIntraNodeAllGather(void* commState, int *ranks, int rank, int nranks, void* allData, int size); +ncclResult_t bootstrapIntraNodeBroadcast(void* commState, int *ranks, int rank, int nranks, int root, void* bcastData, int size); ncclResult_t bootstrapClose(void* commState); ncclResult_t bootstrapAbort(void* commState); #endif diff --git a/src/include/collectives.h b/src/include/collectives.h index fbad473a70..0fb2badb66 100644 --- a/src/include/collectives.h +++ b/src/include/collectives.h @@ -35,12 +35,6 @@ struct ncclDevRedOpFull { #define NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type) \ ncclKernelDebug_##func##_##algo##_##proto##_##devredop##_##type -#define NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type) \ - ncclKernelLL128_##func##_##algo##_##proto##_##devredop##_##type - -#define NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type) \ - ncclKernelLL128Debug_##func##_##algo##_##proto##_##devredop##_##type - #define NCCL_IMPL_NAME(func, algo, proto) \ nccl##func##algo##proto @@ -49,16 +43,12 @@ struct ncclDevRedOpFull { #define DECL5(func, algo, proto, devredop, type) \ extern __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \ extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \ - extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \ - extern __global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \ - extern __global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); + extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); #else #define DECL5(func, algo, proto, devredop, type) \ extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \ extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \ - extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \ - extern __global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \ - extern __global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); + extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); #endif #define SINGLE_ARG(...) __VA_ARGS__ @@ -76,7 +66,8 @@ struct ncclDevRedOpFull { DECL4(func, RING, devredop, type, undef) \ DECL4(func, TREE, devredop, type, undef) \ DECL4(func, COLLNET_DIRECT, devredop, type, undef) \ - DECL4(func, COLLNET_CHAIN, devredop, type, undef) + DECL4(func, COLLNET_CHAIN, devredop, type, undef) \ + DECL4(func, NVLS, devredop, type, undef) #if defined(RCCL_BFLOAT16) #define DECL2(func, devredop, undefForFloat) \ @@ -147,4 +138,13 @@ extern __device__ void NCCL_ONERANK_REDUCE_NAME(PreMulSum, double)(); #define ALLTOALL_PIVOT_SLICESTEPS 2 #define ALLTOALL_PIVOT_CHUNKSTEPS 4 +// We can't use the enum identifiers like ncclSum, ncclFloat, etc since this +// macro will be used in preprocessor conditionals where enums have no meaning. +#define NCCL_NVLS_SUPPORTS(/*ncclDataType_t*/ type, /*ncclDevRedOp_t*/ red) \ + (((type==2 || type==3) && (red==0 || red==2 || red==3)) || \ + ((type==4 || type==5) && (red==0 || red==2 || red==3)) || \ + ((type==6 || type==9) && (red==0 || red==2 || red==3)) || \ + (type==7 && red==0) || \ + (type==8 && red==0)) + #endif diff --git a/src/include/comm.h b/src/include/comm.h index 07be8ac6f6..dac5cc8f53 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -110,6 +110,7 @@ struct ncclChannel { struct ncclTree collnetChain; struct ncclDirect collnetDirect; struct ncclTree binTree; + struct ncclNvls nvls; int id; // index of this channel uint32_t workFifoSent; // last used work index+1 uint64_t p2pOpCount; @@ -183,10 +184,12 @@ struct ncclComm { int nRanks; // number of GPUs in communicator int cudaDev; // my cuda device index int compCap; // compute capability of the GPU + int minCompCap; // min compute capability in the communicator int64_t busId; // my PCI bus ID in int format cpu_set_t cpuAffinity; // CPU affinity of the GPU int WarpSize; int virtualId; + int cudaArch; // matches __CUDA_ARCH__ of device int node; int nNodes; @@ -209,6 +212,7 @@ struct ncclComm { // Channels for collectives int nChannels; + int nvlsChannels; // Channels (per peer) for p2p int p2pnChannels; int p2pnChannelsPerPeer; @@ -270,6 +274,10 @@ struct ncclComm { int collNetSupport; int intraHighestTransportType; + // NVLink SHARP (NVLS) support + int nvlsSupport; + void* nvlsResources; + size_t channelSize; // User requested work size (bytes) for channel partitions // Internal streams @@ -313,6 +321,11 @@ struct ncclComm { // communicator mode int blocking; + // CGA cluster size + int cgaClusterSize; + int minCTAs, maxCTAs; + // network interface name + char *netName; // initState is to more conveniently reclaim resources when errors happen. ncclResult_t initState; // flag to indicate if ncclCommFinalize() is called diff --git a/src/include/cudawrap.h b/src/include/cudawrap.h index 0fd594582a..317ca2df6d 100644 --- a/src/include/cudawrap.h +++ b/src/include/cudawrap.h @@ -73,10 +73,32 @@ DECLARE_CUDA_PFN_EXTERN(cuGetErrorName, 6000); DECLARE_CUDA_PFN_EXTERN(cuMemGetAddressRange, 3020); DECLARE_CUDA_PFN_EXTERN(cuCtxCreate, 3020); DECLARE_CUDA_PFN_EXTERN(cuCtxDestroy, 4000); +DECLARE_CUDA_PFN_EXTERN(cuCtxGetCurrent, 4000); DECLARE_CUDA_PFN_EXTERN(cuCtxSetCurrent, 4000); +DECLARE_CUDA_PFN_EXTERN(cuCtxGetDevice, 2000); +// cuMem API support +DECLARE_CUDA_PFN_EXTERN(cuMemAddressReserve, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemAddressFree, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemCreate, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemGetAllocationGranularity, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemExportToShareableHandle, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemImportFromShareableHandle, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemMap, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemRelease, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemSetAccess, 10020); +DECLARE_CUDA_PFN_EXTERN(cuMemUnmap, 10020); #if CUDA_VERSION >= 11070 DECLARE_CUDA_PFN_EXTERN(cuMemGetHandleForAddressRange, 11070); // DMA-BUF support #endif +#if CUDA_VERSION >= 12010 +/* NVSwitch Multicast support */ +DECLARE_CUDA_PFN_EXTERN(cuMulticastAddDevice, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastBindMem, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastBindAddr, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastCreate, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastGetGranularity, 12010); +DECLARE_CUDA_PFN_EXTERN(cuMulticastUnbind, 12010); +#endif #endif /* CUDA Driver functions loaded with dlsym() */ @@ -88,6 +110,7 @@ DECLARE_CUDA_PFN_EXTERN(cuGetProcAddress, 11030); ncclResult_t ncclCudaLibraryInit(void); extern int ncclCudaDriverVersionCache; +extern bool ncclCudaLaunchBlocking; // initialized by ncclCudaLibraryInit() inline ncclResult_t ncclCudaDriverVersion(int* driver) { int version = __atomic_load_n(&ncclCudaDriverVersionCache, __ATOMIC_RELAXED); @@ -98,5 +121,4 @@ inline ncclResult_t ncclCudaDriverVersion(int* driver) { *driver = version; return ncclSuccess; } - #endif diff --git a/src/include/devcomm.h b/src/include/devcomm.h index de8f57d65f..4fc208e284 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -21,11 +21,12 @@ typedef enum { ncclFuncBroadcast, ncclFuncReduce, ncclFuncAllGather, ncclFuncReduceScatter, ncclFuncAllReduce, ncclFuncSendRecv, ncclFuncSend, ncclFuncRecv, ncclFuncAllToAllPivot, ncclNumFuncs} ncclFunc_t; extern const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2]; -#define NCCL_NUM_ALGORITHMS 4 // Tree/Ring/CollNet* +#define NCCL_NUM_ALGORITHMS 5 // Tree/Ring/CollNet* #define NCCL_ALGO_TREE 0 #define NCCL_ALGO_RING 1 #define NCCL_ALGO_COLLNET_DIRECT 2 #define NCCL_ALGO_COLLNET_CHAIN 3 +#define NCCL_ALGO_NVLS 4 extern const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS]; #define NCCL_NUM_PROTOCOLS 3 // Simple/LL/LL128 @@ -88,6 +89,7 @@ static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK #define NCCL_DIRECT_NIC 0x04 #define NCCL_IPC_WRITE 0x08 #define NCCL_IPC_READ 0x10 +#define NCCL_NVLS_MIN_POLL 0x20 struct ncclConnInfo { // Regular comm mechanism @@ -95,7 +97,7 @@ struct ncclConnInfo { uint64_t *tail; // Local for recv, remote for send uint64_t *head; // Local for send, remote for recv - int direct; // Direct communication + int flags; // Direct communication / other flags int shared; // Buffers are shared void **ptrExchange; // Pointer exchange for direct communication uint64_t* redOpArgExchange; // PreOp scaler exchange for direct pull case @@ -154,14 +156,23 @@ struct ncclTree { struct ncclDirect { int depth; int out; - int nHeads; - int headRank; - int shift; + int nHeads; // Number of parallel N<->1<->net operations we'll do in parallel; size of up/down + int headRank; // Index in 0..nHeads-1 I am the head rank of. -1 if I'm not a head rank (no local NIC) + int shift; // Shuffling of send/recv for scatter/gather operations, basically localRank%nHeads int up[NCCL_MAX_DIRECT_ARITY]; int down[NCCL_MAX_DIRECT_ARITY]; }; #define NCCL_CONN_IDX_P2P_NET 2 +#define NCCL_MAX_NVLS_ARITY 8 +struct ncclNvls { + int out; + int nHeads; // Number of parallel N<->1<->net operations we'll do in parallel; size of up/down + int headRank; // Index in 0..nHeads-1 I am the head rank of. -1 if I'm not a head rank (no local NIC) + int up[NCCL_MAX_NVLS_ARITY]; + int down; +}; + #define NCCL_MAX_CONNS 3 struct ncclChannelPeer { struct ncclConnector send[NCCL_MAX_CONNS]; @@ -361,6 +372,7 @@ struct alignas(16) ncclDevChannel { struct ncclTree collnetChain; struct ncclDirect collnetDirect; struct ncclTree binTree; + struct ncclNvls nvls; uint32_t* workFifoDone; // Location of done counter, device writes index+1 of last work processed }; @@ -399,4 +411,65 @@ struct alignas(16) ncclDevCommAndChannels { struct ncclDevChannel channels[MAXCHANNELS]; }; +#ifdef __CUDA_ARCH__ + #define NCCL_CUDA_ARCH __CUDA_ARCH__ +#else + #define NCCL_CUDA_ARCH 0 +#endif + +template +__host__ __device__ constexpr T min_constexpr(T a) { return a; } +template +__host__ __device__ constexpr T min_constexpr(T a, T b, Ts ...c) { + return min_constexpr((a < b ? a : b), c...); +} + +template +__host__ __device__ constexpr T max_constexpr(T a) { return a; } +template +__host__ __device__ constexpr T max_constexpr(T a, T b, Ts ...c) { + return max_constexpr((a > b ? a : b), c...); +} + +// Calculate the unroll factor given: +// * bytePerPack: number of bytes accessed per instruction +// * insns: max permissible unroll value +// * bytes: desired number of in-flight bytes per iteration ( = unroll*bytePerPack) +__host__ __device__ constexpr int ncclCalcUnroll(int bytePerPack, int insns, int bytes) { + return min_constexpr(insns, (bytes + bytePerPack-1)/bytePerPack); +} + +// Note that all unroll value logic should depend on a given cudaArch argument +// and not __CUDA_ARCH__ since these need to be host-side executable where the +// arch value is strictly runtime only. By defaulting to NCCL_CUDA_ARCH, device +// side code can elide passing the arch for brevity. + +__host__ __device__ constexpr int ncclCollUnroll(int cudaArch = NCCL_CUDA_ARCH) { + // Our collective unroll should move to the same bytes&insns model as NVLS. + return cudaArch >= 800 ? 8 : 4; +} + +__host__ __device__ constexpr int ncclNvlsUnrollBytes(int cudaArch = NCCL_CUDA_ARCH) { return 4*16; } +__host__ __device__ constexpr int ncclNvlsUnrollInsns(int cudaArch = NCCL_CUDA_ARCH) { return 16; } + +__host__ __device__ constexpr int ncclNvlsUnroll(int bytePerPack, int cudaArch = NCCL_CUDA_ARCH) { + return ncclCalcUnroll(bytePerPack, ncclNvlsUnrollInsns(cudaArch), ncclNvlsUnrollBytes(cudaArch)); +} + +// The amount of dynamic shmem per warp +__host__ __device__ constexpr int ncclShmemScratchWarpSize(int cudaArch = NCCL_CUDA_ARCH) { + return (max_constexpr( + /*LL */0, + /*LL128 */(NCCL_LL128_SHMEM_ELEMS_PER_THREAD*WARP_SIZE)*sizeof(uint64_t), + /*SIMPLE*/(ncclCollUnroll(cudaArch)*WARP_SIZE + 1)*16, + // NVLS needs an extra 16B to read unaligned data. + /*NVLS */WARP_SIZE*(cudaArch >= 900 ? ncclNvlsUnrollBytes(cudaArch) : 0) + 16 + ) + 15) & -16; // pad to 16 bytes +} + +// The amount of dynamic shmem per block +__host__ __device__ constexpr int ncclShmemDynamicSize(int cudaArch = NCCL_CUDA_ARCH) { + return cudaArch < 700 ? 0 : ncclShmemScratchWarpSize(cudaArch)*(NCCL_MAX_NTHREADS/WARP_SIZE); +} + #endif diff --git a/src/include/enqueue.h b/src/include/enqueue.h index 7c15654e0f..634f037cb3 100644 --- a/src/include/enqueue.h +++ b/src/include/enqueue.h @@ -15,9 +15,7 @@ #define NCCL_MIN_CHANNEL_SIZE (NCCL_LL_THREAD_THRESHOLD*64) #define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */ -size_t ncclKernMaxLocalSize(); -size_t ncclKernLocalSize(int i); -ncclResult_t ncclKernSetSharedMemoryCarveout(int carveOut); +ncclResult_t ncclInitKernelsForDevice(int cudaArch, size_t* maxStackSize); ncclResult_t ncclEnqueueCheck(struct ncclInfo* info); ncclResult_t ncclLaunchPrepare(struct ncclComm* comm); ncclResult_t ncclLaunchKernelBefore_NoUncapturedCuda(struct ncclComm* comm, struct ncclKernelPlan* plan); diff --git a/src/include/graph.h b/src/include/graph.h index 2ce4b4bffb..38b17d5113 100644 --- a/src/include/graph.h +++ b/src/include/graph.h @@ -119,9 +119,6 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa ncclResult_t ncclTreeBasePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph); -ncclResult_t ncclBinaryTreePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph); -ncclResult_t ncclBinaryTreeHayabusaPostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph); - ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph); #include "info.h" ncclResult_t ncclTopoGetAlgoTime(struct ncclInfo* info, int algorithm, int protocol, int numPipeOps, float* time); diff --git a/src/include/info.h b/src/include/info.h index feaffb3f56..193d820f51 100644 --- a/src/include/info.h +++ b/src/include/info.h @@ -25,6 +25,7 @@ typedef enum : uint8_t { ncclPatternTreeUpDown, ncclPatternCollnetChain, ncclPatternCollnetDirect, + ncclPatternNvls, ncclPatternSend, ncclPatternRecv } ncclPattern_t; diff --git a/src/include/ipcsocket.h b/src/include/ipcsocket.h new file mode 100644 index 0000000000..700f0bcdeb --- /dev/null +++ b/src/include/ipcsocket.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2016-2023, NVIDIA CORPORATION. All rights reserved. + * + * See COPYRIGHT for license information + */ + +#ifndef NCCL_IPCSOCKET_H +#define NCCL_IPCSOCKET_H + +#include "nccl.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define NCCL_IPC_SOCKNAME_LEN 64 + +struct ncclIpcSocket { + int fd; + char socketName[NCCL_IPC_SOCKNAME_LEN]; + volatile uint32_t* abortFlag; +}; + +ncclResult_t ncclIpcSocketInit(struct ncclIpcSocket *handle, int rank, uint64_t hash, volatile uint32_t* abortFlag); +ncclResult_t ncclIpcSocketClose(struct ncclIpcSocket *handle); + +ncclResult_t ncclIpcSocketRecvFd(struct ncclIpcSocket *handle, int *fd); +ncclResult_t ncclIpcSocketSendFd(struct ncclIpcSocket *handle, const int fd, int rank, uint64_t hash); + +#endif /* NCCL_IPCSOCKET_H */ diff --git a/src/include/nccl_net.h b/src/include/nccl_net.h index 255a44ee28..a387e66d7a 100644 --- a/src/include/nccl_net.h +++ b/src/include/nccl_net.h @@ -20,7 +20,7 @@ #define NCCL_NET_MAX_REQUESTS 8 typedef enum {NCCL_LOG_NONE=0, NCCL_LOG_VERSION=1, NCCL_LOG_WARN=2, NCCL_LOG_INFO=3, NCCL_LOG_ABORT=4, NCCL_LOG_TRACE=5} ncclDebugLogLevel; -typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_CALL=512, NCCL_ALL=~0} ncclDebugLogSubSys; +typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_CALL=512, NCCL_PROXY=1024, NCCL_NVLS=2048, NCCL_ALL=~0} ncclDebugLogSubSys; typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); diff --git a/src/include/nvtx.h b/src/include/nvtx.h index 2aeb93286d..ab32ef27f0 100644 --- a/src/include/nvtx.h +++ b/src/include/nvtx.h @@ -7,12 +7,12 @@ #ifndef NCCL_NVTX_H_ #define NCCL_NVTX_H_ -#include "nvtx3.hpp" +#include "nvtx3/nvtx3.hpp" -#if __cpp_constexpr >= 201304L && !defined(NVTX3_RELAXED_CONSTEXPR) -#define NVTX3_RELAXED_CONSTEXPR constexpr +#if __cpp_constexpr >= 201304L && !defined(NVTX3_CONSTEXPR_IF_CPP14) +#define NVTX3_CONSTEXPR_IF_CPP14 constexpr #else -#define NVTX3_RELAXED_CONSTEXPR +#define NVTX3_CONSTEXPR_IF_CPP14 #endif // Define all NCCL-provided static schema IDs here (avoid duplicates). @@ -37,7 +37,7 @@ struct nccl_domain{static constexpr char const* name{"NCCL"};}; class payload_schema { public: - NVTX3_RELAXED_CONSTEXPR explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept + explicit payload_schema(const nvtxPayloadSchemaEntry_t entries[], size_t numEntries, const uint64_t schemaId, const char* schemaName = nullptr) noexcept { schema_attr.name = schemaName; schema_attr.entries = entries; @@ -74,11 +74,11 @@ class payload_schema { #define NVTX3_FUNC_WITH_PARAMS(ID, S, P) \ static const payload_schema schema{S, std::extent::value, \ NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, #ID}; \ - static ::nvtx3::v1::registered_string const nvtx3_func_name__{__func__}; \ + static ::nvtx3::v1::registered_string_in const nvtx3_func_name__{__func__}; \ nvtxPayloadData_t nvtx3_bpl__[] = { \ {NVTX_PAYLOAD_ENTRY_TYPE_SCHEMA_ID_STATIC_START + NVTX_SID_##ID, sizeof(P), &(P)}}; \ - ::nvtx3::v1::event_attributes nvtx3_func_attr__{nvtx3_func_name__, nvtx3_bpl__}; \ - ::nvtx3::v1::domain_thread_range const nvtx3_range__{nvtx3_func_attr__}; + ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__, nvtx3_bpl__}; \ + ::nvtx3::v1::scoped_range_in const nvtx3_range__{nvtx3_func_attr__}; extern void initNvtxRegisteredEnums(); diff --git a/src/include/nvtx3/nvToolsExt.h b/src/include/nvtx3/nvToolsExt.h index ce4b0be651..10938385d3 100644 --- a/src/include/nvtx3/nvToolsExt.h +++ b/src/include/nvtx3/nvToolsExt.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvToolsExtCuda.h b/src/include/nvtx3/nvToolsExtCuda.h index b1e654c746..b1b80ad671 100644 --- a/src/include/nvtx3/nvToolsExtCuda.h +++ b/src/include/nvtx3/nvToolsExtCuda.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvToolsExtCudaRt.h b/src/include/nvtx3/nvToolsExtCudaRt.h index 002f6e9975..1e19958ec9 100644 --- a/src/include/nvtx3/nvToolsExtCudaRt.h +++ b/src/include/nvtx3/nvToolsExtCudaRt.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvToolsExtOpenCL.h b/src/include/nvtx3/nvToolsExtOpenCL.h index 611c0cb07f..a7b8a19b0c 100644 --- a/src/include/nvtx3/nvToolsExtOpenCL.h +++ b/src/include/nvtx3/nvToolsExtOpenCL.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvToolsExtPayload.h b/src/include/nvtx3/nvToolsExtPayload.h index 1683f929d7..a46c833e2f 100644 --- a/src/include/nvtx3/nvToolsExtPayload.h +++ b/src/include/nvtx3/nvToolsExtPayload.h @@ -1,12 +1,12 @@ /* -* Copyright 2021 NVIDIA Corporation. All rights reserved. +* Copyright 2021-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception */ -#include "nvtx3/nvToolsExt.h" +#include "nvToolsExt.h" #ifndef NVTOOLSEXT_PAYLOAD_H #define NVTOOLSEXT_PAYLOAD_H diff --git a/src/include/nvtx3/nvToolsExtSync.h b/src/include/nvtx3/nvToolsExtSync.h index 5d2472962d..113fcd1910 100644 --- a/src/include/nvtx3/nvToolsExtSync.h +++ b/src/include/nvtx3/nvToolsExtSync.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3.hpp b/src/include/nvtx3/nvtx3.hpp similarity index 51% rename from src/include/nvtx3.hpp rename to src/include/nvtx3/nvtx3.hpp index 353fddfd21..cb0ef6858f 100644 --- a/src/include/nvtx3.hpp +++ b/src/include/nvtx3/nvtx3.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,15 +20,15 @@ /* This section handles the decision of whether to provide unversioned symbols. * If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is #defined, unversioned symbols are - * not provided, and explicit-version symbols such as nvtx3::v1::thread_range + * not provided, and explicit-version symbols such as nvtx3::v1::scoped_range * and NVTX3_V1_FUNC_RANGE must be used. By default, the first #include of this - * header will define the unversioned symbols such as nvtx3::thread_range and + * header will define the unversioned symbols such as nvtx3::scoped_range and * NVTX3_FUNC_RANGE. Subsequently including a different major version of this * header without #defining NVTX3_CPP_REQUIRE_EXPLICIT_VERSION triggers an error * since the symbols would conflict. Subsequently including of a different * minor version within the same major version is allowed. Functionality of * minor versions is cumulative, regardless of include order. - * + * * Since NVTX3_CPP_REQUIRE_EXPLICIT_VERSION allows all combinations of versions * to coexist without problems within a translation unit, the recommended best * practice for instrumenting header-based libraries with NVTX C++ Wrappers is @@ -39,66 +39,58 @@ */ /* clang-format off */ #if !defined(NVTX3_CPP_REQUIRE_EXPLICIT_VERSION) - /* Define macro used by all definitions in this header to indicate the - * unversioned symbols should be defined in addition to the versioned ones. + /* Define macro used by all definitions in this header to indicate the + * unversioned symbols should be defined in addition to the versioned ones. + */ + #define NVTX3_INLINE_THIS_VERSION + + #if !defined(NVTX3_CPP_INLINED_VERSION_MAJOR) + /* First occurrence of this header in the translation unit. Define macros + * indicating which version shall be used for unversioned symbols. */ - #define NVTX3_INLINE_THIS_VERSION - #if !defined(NVTX3_CPP_INLINED_VERSION_MAJOR) - /* First occurrence of this header in the translation unit. Define macros - * indicating which version shall be used for unversioned symbols. - */ + /** + * @brief Semantic major version number for NVTX C++ wrappers of unversioned symbols + * + * Breaking changes may occur between major versions, and different major versions + * cannot provide unversioned symbols in the same translation unit (.cpp file). + * + * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined. + * + * Not to be confused with the version number of the NVTX core library. + */ + #define NVTX3_CPP_INLINED_VERSION_MAJOR 1 // NVTX3_CPP_VERSION_MAJOR - /** - * @brief Semantic major version number for NVTX C++ wrappers of unversioned symbols - * - * Breaking changes may occur between major versions, and different major versions - * cannot provide unversioned symbols in the same translation unit (.cpp file). - * - * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined. - * - * Not to be confused with the version number of the NVTX core library. - */ - #define NVTX3_CPP_INLINED_VERSION_MAJOR 1 // NVTX3_CPP_VERSION_MAJOR - - /** - * @brief Semantic minor version number for NVTX C++ wrappers of unversioned symbols - * - * No breaking changes occur between minor versions -- minor version changes within - * a major version are purely additive. - * - * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined. - * - * Not to be confused with the version number of the NVTX core library. - */ - #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR - #elif NVTX3_CPP_INLINED_VERSION_MAJOR != NVTX3_CPP_VERSION_MAJOR - /* Unsupported case -- cannot define unversioned symbols for different major versions - * in the same translation unit. - */ - #error \ - "Two different major versions of the NVTX C++ Wrappers are being included in a single .cpp file, with unversioned symbols enabled in both. Only one major version can enable unversioned symbols in a .cpp file. To disable unversioned symbols, #define NVTX3_CPP_REQUIRE_EXPLICIT_VERSION before #including nvtx3.hpp, and use the explicit-version symbols instead -- this is the preferred way to use nvtx3.hpp from a header file." - #elif (NVTX3_CPP_INLINED_VERSION_MAJOR == NVTX3_CPP_VERSION_MAJOR) && \ - (NVTX3_CPP_INLINED_VERSION_MINOR < NVTX3_CPP_VERSION_MINOR) - /* An older minor version of the same major version already defined unversioned - * symbols. The new features provided in this header will be inlined - * redefine the minor version macro to this header's version. - */ - #undef NVTX3_CPP_INLINED_VERSION_MINOR - #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR - // else, already have this version or newer, nothing to do - #endif + /** + * @brief Semantic minor version number for NVTX C++ wrappers of unversioned symbols + * + * No breaking changes occur between minor versions -- minor version changes within + * a major version are purely additive. + * + * Note: If NVTX3_CPP_REQUIRE_EXPLICIT_VERSION is defined, this macro is not defined. + * + * Not to be confused with the version number of the NVTX core library. + */ + #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR + #elif NVTX3_CPP_INLINED_VERSION_MAJOR != NVTX3_CPP_VERSION_MAJOR + /* Unsupported case -- cannot define unversioned symbols for different major versions + * in the same translation unit. + */ + #error \ + "Two different major versions of the NVTX C++ Wrappers are being included in a single .cpp file, with unversioned symbols enabled in both. Only one major version can enable unversioned symbols in a .cpp file. To disable unversioned symbols, #define NVTX3_CPP_REQUIRE_EXPLICIT_VERSION before #including nvtx3.hpp, and use the explicit-version symbols instead -- this is the preferred way to use nvtx3.hpp from a header file." + #elif (NVTX3_CPP_INLINED_VERSION_MAJOR == NVTX3_CPP_VERSION_MAJOR) && \ + (NVTX3_CPP_INLINED_VERSION_MINOR < NVTX3_CPP_VERSION_MINOR) + /* An older minor version of the same major version already defined unversioned + * symbols. The new features provided in this header will be inlined + * redefine the minor version macro to this header's version. + */ + #undef NVTX3_CPP_INLINED_VERSION_MINOR + #define NVTX3_CPP_INLINED_VERSION_MINOR 0 // NVTX3_CPP_VERSION_MINOR + // else, already have this version or newer, nothing to do + #endif #endif /* clang-format on */ -#include -#include - -#include -#include -#include -#include - /** * @file nvtx3.hpp * @@ -112,19 +104,19 @@ * * \section QUICK_START Quick Start * - * To add NVTX ranges to your code, use the `nvtx3::thread_range` RAII object. A + * To add NVTX ranges to your code, use the `nvtx3::scoped_range` RAII object. A * range begins when the object is created, and ends when the object is * destroyed. * * \code{.cpp} * #include "nvtx3.hpp" - * void some_function(){ + * void some_function() { * // Begins a NVTX range with the messsage "some_function" * // The range ends when some_function() returns and `r` is destroyed - * nvtx3::thread_range r{"some_function"}; + * nvtx3::scoped_range r{"some_function"}; * - * for(int i = 0; i < 6; ++i){ - * nvtx3::thread_range loop{"loop range"}; + * for(int i = 0; i < 6; ++i) { + * nvtx3::scoped_range loop{"loop range"}; * std::this_thread::sleep_for(std::chrono::seconds{1}); * } * } // Range ends when `r` is destroyed @@ -142,10 +134,9 @@ * * \code{.cpp} * #include "nvtx3.hpp" - * void some_function(){ + * void some_function() { * // Creates a range with a message "some_function" that ends when the - * enclosing - * // function returns + * // enclosing function returns * NVTX3_FUNC_RANGE(); * ... * } @@ -165,66 +156,66 @@ * be accomplished with an NVTX range created on the entry to the function and * terminated on return from `my_function` using the push/pop C APIs: * - * ``` - * void my_function(...){ + * \code{.cpp} + * void my_function(...) { * nvtxRangePushA("my_function"); // Begins NVTX range * // do work * nvtxRangePop(); // Ends NVTX range * } - * ``` + * \endcode * * One of the challenges with using the NVTX C API is that it requires manually * terminating the end of the range with `nvtxRangePop`. This can be challenging * if `my_function()` has multiple returns or can throw exceptions as it * requires calling `nvtxRangePop()` before all possible return points. * - * NVTX++ solves this inconvenience through the "RAII" technique by providing a - * `nvtx3::thread_range` class that begins a range at construction and ends the - * range on destruction. The above example then becomes: + * NVTX C++ solves this inconvenience through the "RAII" technique by providing + * a `nvtx3::scoped_range` class that begins a range at construction and ends + * the range on destruction. The above example then becomes: * - * ``` - * void my_function(...){ - * nvtx3::thread_range r{"my_function"}; // Begins NVTX range + * \code{.cpp} + * void my_function(...) { + * nvtx3::scoped_range r{"my_function"}; // Begins NVTX range * // do work * } // Range ends on exit from `my_function` when `r` is destroyed - * ``` + * \endcode * * The range object `r` is deterministically destroyed whenever `my_function` * returns---ending the NVTX range without manual intervention. For more - * information, see \ref RANGES and `nvtx3::domain_thread_range`. + * information, see \ref RANGES and `nvtx3::scoped_range_in`. * * Another inconvenience of the NVTX C APIs are the several constructs where the * user is expected to initialize an object at the beginning of an application * and reuse that object throughout the lifetime of the application. For example - * Domains, Categories, and Registered messages. + * see domains, categories, and registered messages. * * Example: - * ``` + * \code{.cpp} * nvtxDomainHandle_t D = nvtxDomainCreateA("my domain"); * // Reuse `D` throughout the rest of the application - * ``` + * \endcode * * This can be problematic if the user application or library does not have an * explicit initialization function called before all other functions to * ensure that these long-lived objects are initialized before being used. * - * NVTX++ makes use of the "construct on first use" technique to alleviate this - * inconvenience. In short, a function local static object is constructed upon - * the first invocation of a function and returns a reference to that object on - * all future invocations. See the documentation for - * `nvtx3::registered_string`, `nvtx3::domain`, `nvtx3::named_category`, and + * NVTX C++ makes use of the "construct on first use" technique to alleviate + * this inconvenience. In short, a function local static object is constructed + * upon the first invocation of a function and returns a reference to that + * object on all future invocations. See the documentation for `nvtx3::domain`, + * `nvtx3::named_category`, `nvtx3::registered_string`, and * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use for more * information. * * Using construct on first use, the above example becomes: - * ``` + * \code{.cpp} * struct my_domain{ static constexpr char const* name{"my domain"}; }; * * // The first invocation of `domain::get` for the type `my_domain` will * // construct a `nvtx3::domain` object and return a reference to it. Future * // invocations simply return a reference. * nvtx3::domain const& D = nvtx3::domain::get(); - * ``` + * \endcode * For more information about NVTX and how it can be used, see * https://docs.nvidia.com/cuda/profiler-users-guide/index.html#nvtx and * https://devblogs.nvidia.com/cuda-pro-tip-generate-custom-application-profile-timelines-nvtx/ @@ -236,106 +227,108 @@ * application. Common examples are using ranges to annotate the time it takes * to execute a function or an iteration of a loop. * - * NVTX++ uses RAII to automate the generation of ranges that are tied to the + * NVTX C++ uses RAII to automate the generation of ranges that are tied to the * lifetime of objects. Similar to `std::lock_guard` in the C++ Standard * Template Library. * - * \subsection THREAD_RANGE Thread Range + * \subsection scoped_range Scoped Range * - * `nvtx3::domain_thread_range` is a class that begins a range upon construction + * `nvtx3::scoped_range_in` is a class that begins a range upon construction * and ends the range at destruction. This is one of the most commonly used - * constructs in NVTX++ and is useful for annotating spans of time on a + * constructs in NVTX C++ and is useful for annotating spans of time on a * particular thread. These ranges can be nested to arbitrary depths. * - * `nvtx3::thread_range` is an alias for a `nvtx3::domain_thread_range` in the + * `nvtx3::scoped_range` is an alias for a `nvtx3::scoped_range_in` in the * global NVTX domain. For more information about Domains, see \ref DOMAINS. * * Various attributes of a range can be configured constructing a - * `nvtx3::domain_thread_range` with a `nvtx3::event_attributes` object. For + * `nvtx3::scoped_range_in` with a `nvtx3::event_attributes` object. For * more information, see \ref ATTRIBUTES. * * Example: * * \code{.cpp} - * void some_function(){ + * void some_function() { * // Creates a range for the duration of `some_function` - * nvtx3::thread_range r{}; + * nvtx3::scoped_range r{}; * - * while(true){ + * while(true) { * // Creates a range for every loop iteration * // `loop_range` is nested inside `r` - * nvtx3::thread_range loop_range{}; + * nvtx3::scoped_range loop_range{}; * } * } * \endcode * - * \subsection PROCESS_RANGE Process Range + * \subsection unique_range Unique Range * - * `nvtx3::domain_process_range` is identical to `nvtx3::domain_thread_range` - * with the exception that a `domain_process_range` can be created and destroyed - * on different threads. This is useful to annotate spans of time that can - * bridge multiple threads. + * `nvtx3::unique_range` is similar to `nvtx3::scoped_range`, with a few key differences: + * - `unique_range` objects can be destroyed in any order whereas `scoped_range` objects must be + * destroyed in exact reverse creation order + * - `unique_range` can start and end on different threads + * - `unique_range` is moveable + * - `unique_range` objects can be constructed as heap objects * - * `nvtx3::domain_thread_range`s should be preferred unless one needs the - * ability to begin and end a range on different threads. + * There is extra overhead associated with `unique_range` constructs and therefore use of + * `nvtx3::scoped_range_in` should be preferred. * * \section MARKS Marks * - * `nvtx3::mark` allows annotating an instantaneous event in an application's - * timeline. For example, indicating when a mutex is locked or unlocked. + * `nvtx3::mark` annotates an instantaneous point in time with a "marker". + * + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: * * \code{.cpp} - * std::mutex global_lock; - * void lock_mutex(){ - * global_lock.lock(); - * // Marks an event immediately after the mutex is locked - * nvtx3::mark("lock_mutex"); + * bool success = do_operation(...); + * if (!success) { + * nvtx3::mark("operation failed!"); * } * \endcode * * \section DOMAINS Domains * - * Similar to C++ namespaces, Domains allow for scoping NVTX events. By default, + * Similar to C++ namespaces, domains allow for scoping NVTX events. By default, * all NVTX events belong to the "global" domain. Libraries and applications * should scope their events to use a custom domain to differentiate where the * events originate from. * * It is common for a library or application to have only a single domain and * for the name of that domain to be known at compile time. Therefore, Domains - * in NVTX++ are represented by _tag types_. + * in NVTX C++ are represented by _tag types_. * - * For example, to define a custom domain, simply define a new concrete type + * For example, to define a custom domain, simply define a new concrete type * (a `class` or `struct`) with a `static` member called `name` that contains * the desired name of the domain. * - * ``` + * \code{.cpp} * struct my_domain{ static constexpr char const* name{"my domain"}; }; - * ``` + * \endcode * - * For any NVTX++ construct that can be scoped to a domain, the type `my_domain` - * can be passed as an explicit template argument to scope it to the custom - * domain. + * For any NVTX C++ construct that can be scoped to a domain, the type + * `my_domain` can be passed as an explicit template argument to scope it to + * the custom domain. * * The tag type `nvtx3::domain::global` represents the global NVTX domain. * * \code{.cpp} - * // By default, `domain_thread_range` belongs to the global domain - * nvtx3::domain_thread_range<> r0{}; + * // By default, `scoped_range_in` belongs to the global domain + * nvtx3::scoped_range_in<> r0{}; * - * // Alias for a `domain_thread_range` in the global domain - * nvtx3::thread_range r1{}; + * // Alias for a `scoped_range_in` in the global domain + * nvtx3::scoped_range r1{}; * * // `r` belongs to the custom domain - * nvtx3::domain_thread_range r{}; + * nvtx3::scoped_range_in r{}; * \endcode * - * When using a custom domain, it is reccomended to define type aliases for NVTX + * When using a custom domain, it is recommended to define type aliases for NVTX * constructs in the custom domain. - * ``` - * using my_thread_range = nvtx3::domain_thread_range; - * using my_registered_string = nvtx3::registered_string; - * using my_named_category = nvtx3::named_category; - * ``` + * \code{.cpp} + * using my_scoped_range = nvtx3::scoped_range_in; + * using my_registered_string = nvtx3::registered_string_in; + * using my_named_category = nvtx3::named_category_in; + * \endcode * * See `nvtx3::domain` for more information. * @@ -359,35 +352,41 @@ * information. * * \code{.cpp} - * // Custom color, message - * event_attributes attr{nvtx3::rgb{127, 255, 0}, - * "message"}; + * // Set message, same as passing nvtx3::message{"message"} + * nvtx3::event_attributes attr{"message"}; * - * // Custom color, message, payload, category - * event_attributes attr{nvtx3::rgb{127, 255, 0}, - * nvtx3::payload{42}, - * "message", - * nvtx3::category{1}}; + * // Set message and color + * nvtx3::event_attributes attr{"message", nvtx3::rgb{127, 255, 0}}; * - * // Arguments can be in any order - * event_attributes attr{nvtx3::payload{42}, - * nvtx3::category{1}, - * "message", - * nvtx3::rgb{127, 255, 0}}; + * // Set message, color, payload, category + * nvtx3::event_attributes attr{"message", + * nvtx3::rgb{127, 255, 0}, + * nvtx3::payload{42}, + * nvtx3::category{1}}; * - * // "First wins" with multiple arguments of the same type - * event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; // payload is - * 42 \endcode + * // Same as above -- can use any order of arguments + * nvtx3::event_attributes attr{nvtx3::payload{42}, + * nvtx3::category{1}, + * "message", + * nvtx3::rgb{127, 255, 0}}; + * + * // Multiple arguments of the same type are allowed, but only the first is + * // used -- in this example, payload is set to 42: + * nvtx3::event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; + * + * // Using the nvtx3 namespace in a local scope makes the syntax more succinct: + * using namespace nvtx3; + * event_attributes attr{"message", rgb{127, 255, 0}, payload{42}, category{1}}; + * \endcode * * \subsection MESSAGES message * - * A `nvtx3::message` allows associating a custom message string with an NVTX - * event. + * `nvtx3::message` sets the message string for an NVTX event. * * Example: * \code{.cpp} - * // Create an `event_attributes` with the custom message "my message" - * nvtx3::event_attributes attr{nvtx3::Mesage{"my message"}}; + * // Create an `event_attributes` with the message "my message" + * nvtx3::event_attributes attr{nvtx3::message{"my message"}}; * * // strings and string literals implicitly assumed to be a `nvtx3::message` * nvtx3::event_attributes attr{"my message"}; @@ -415,8 +414,8 @@ * * Example: * \code{.cpp} - * // Explicitly constructed, static `registered_string` - * static registered_string static_message{"my message"}; + * // Explicitly constructed, static `registered_string` in my_domain: + * static registered_string_in static_message{"my message"}; * * // Or use construct on first use: * // Define a tag type with a `message` member string to register @@ -424,8 +423,8 @@ * * // Uses construct on first use to register the contents of * // `my_message::message` - * nvtx3::registered_string const& msg = - * nvtx3::registered_string::get(); \endcode + * auto& msg = nvtx3::registered_string_in::get(); + * \endcode * * \subsection COLOR color * @@ -466,34 +465,32 @@ * custom tag type with static `name` and `id` members. * * \code{.cpp} - * // Explicitly constructed, static `named_category` - * static nvtx3::named_category static_category{42, "my category"}; + * // Explicitly constructed, static `named_category` in my_domain: + * static nvtx3::named_category_in static_category{42, "my category"}; * - * // OR use construct on first use: + * // Or use construct on first use: * // Define a tag type with `name` and `id` members - * struct my_category{ + * struct my_category { * static constexpr char const* name{"my category"}; // category name - * static constexpr category::id_type id{42}; // category id + * static constexpr uint32_t id{42}; // category id * }; * * // Use construct on first use to name the category id `42` - * // with name "my category" - * nvtx3::named_category const& my_category = - * named_category::get(); + * // with name "my category": + * auto& cat = named_category_in::get(); * * // Range `r` associated with category id `42` - * nvtx3::event_attributes attr{my_category}; + * nvtx3::event_attributes attr{cat}; * \endcode * * \subsection PAYLOAD payload * * Allows associating a user-defined numerical value with an event. * - * ``` - * nvtx3:: event_attributes attr{nvtx3::payload{42}}; // Constructs a payload - * from - * // the `int32_t` value 42 - * ``` + * \code{.cpp} + * // Constructs a payload from the `int32_t` value 42 + * nvtx3:: event_attributes attr{nvtx3::payload{42}}; + * \endcode * * * \section EXAMPLE Example @@ -513,34 +510,33 @@ * struct my_message{ static constexpr char const* message{"my message"}; }; * * // For convenience, use aliases for domain scoped objects - * using my_thread_range = nvtx3::domain_thread_range; - * using my_registered_string = nvtx3::registered_string; - * using my_named_category = nvtx3::named_category; + * using my_scoped_range = nvtx3::scoped_range_in; + * using my_registered_string = nvtx3::registered_string_in; + * using my_named_category = nvtx3::named_category_in; * * // Default values for all attributes * nvtx3::event_attributes attr{}; - * my_thread_range r0{attr}; + * my_scoped_range r0{attr}; * * // Custom (unregistered) message, and unnamed category * nvtx3::event_attributes attr1{"message", nvtx3::category{2}}; - * my_thread_range r1{attr1}; + * my_scoped_range r1{attr1}; * * // Alternatively, pass arguments of `event_attributes` ctor directly to - * // `my_thread_range` - * my_thread_range r2{"message", nvtx3::category{2}}; + * // `my_scoped_range` + * my_scoped_range r2{"message", nvtx3::category{2}}; * * // construct on first use a registered string - * auto msg = my_registered_string::get(); + * auto& msg = my_registered_string::get(); * * // construct on first use a named category - * auto category = my_named_category::get(); + * auto& cat = my_named_category::get(); * - * // Use registered string and named category - * my_thread_range r3{msg, category, nvtx3::rgb{127, 255, 0}, - * nvtx3::payload{42}}; + * // Use registered string and named category with a custom payload + * my_scoped_range r3{msg, cat, nvtx3::payload{42}}; * * // Any number of arguments in any order - * my_thread_range r{nvtx3::rgb{127, 255,0}, msg}; + * my_scoped_range r{nvtx3::rgb{127, 255,0}, msg}; * * \endcode * \section MACROS Convenience Macros @@ -550,11 +546,11 @@ * * A convenient way to do this is to use the \ref NVTX3_FUNC_RANGE and * \ref NVTX3_FUNC_RANGE_IN macros. These macros take care of constructing an - * `nvtx3::domain_thread_range` with the name of the enclosing function as the + * `nvtx3::scoped_range_in` with the name of the enclosing function as the * range's message. * * \code{.cpp} - * void some_function(){ + * void some_function() { * // Automatically generates an NVTX range for the duration of the function * // using "some_function" as the event's message. * NVTX3_FUNC_RANGE(); @@ -565,6 +561,25 @@ /* Temporary helper #defines, removed with #undef at end of header */ +#if !defined(NVTX3_USE_CHECKED_OVERLOADS_FOR_GET) +#if defined(_MSC_VER) && _MSC_VER < 1914 +/* Microsoft's compiler prior to VS2017 Update 7 (15.7) uses an older parser + * that does not work with domain::get's specialization for domain::global, + * and would require extra conditions to make SFINAE work for the overloaded + * get() functions. This macro disables use of overloaded get() in order to + * work with VS2015 and versions of VS2017 below 15.7, without penalizing + * users of newer compilers. Building with this flag set to 0 means errors + * when defining tag structs (see documentation for domain, named_category, + * and registered_string) will have more complex compiler error messages + * instead of the clear static_assert messages from the get() overloads. + */ +#define NVTX3_USE_CHECKED_OVERLOADS_FOR_GET 0 +#else +#define NVTX3_USE_CHECKED_OVERLOADS_FOR_GET 1 +#endif +#define NVTX3_USE_CHECKED_OVERLOADS_FOR_GET_DEFINED_HERE +#endif + /* Within this header, nvtx3::NVTX3_VERSION_NAMESPACE resolves to nvtx3::vX, * where "X" is the major version number. */ #define NVTX3_CONCAT(A, B) A##B @@ -580,18 +595,30 @@ #define NVTX3_INLINE_IF_REQUESTED #endif -/* Enables the use of constexpr when support for C++14 relaxed constexpr - * is present. +/* Enables the use of constexpr when support for C++14 constexpr is present. * - * Initializing a legacy-C (i.e., no constructor) union member requires - * initializing in the constructor body. Non-empty constexpr constructors - * require C++14 relaxed constexpr. In strict C++11 compilation, fall back - * to using non-constexpr constructors for classes with union members. + * Initialization of a class member that is a union to a specific union member + * can only be done in the body of a constructor, not in a member initializer + * list. A constexpr constructor must have an empty body until C++14, so there + * is no way to make an initializer of a member union constexpr in C++11. This + * macro allows making functions constexpr in C++14 or newer, but non-constexpr + * in C++11 compilation. It is used here on constructors that initialize their + * member unions. */ #if __cpp_constexpr >= 201304L -#define NVTX3_RELAXED_CONSTEXPR constexpr +#define NVTX3_CONSTEXPR_IF_CPP14 constexpr #else -#define NVTX3_RELAXED_CONSTEXPR +#define NVTX3_CONSTEXPR_IF_CPP14 +#endif + + /* Use a macro for static asserts, which defaults to static_assert, but that + * testing tools can replace with a logging function. For example: + * #define NVTX3_STATIC_ASSERT(c, m) \ + * do { if (!(c)) printf("static_assert would fail: %s\n", m); } while (0) + */ +#if !defined(NVTX3_STATIC_ASSERT) +#define NVTX3_STATIC_ASSERT(condition, message) static_assert(condition, message); +#define NVTX3_STATIC_ASSERT_DEFINED_HERE #endif /* Implementation sections, enclosed in guard macros for each minor version */ @@ -599,6 +626,15 @@ #ifndef NVTX3_CPP_DEFINITIONS_V1_0 #define NVTX3_CPP_DEFINITIONS_V1_0 +#include "nvToolsExt.h" +#include "nvToolsExtPayload.h" + +#include +#include +#include +#include +#include + namespace nvtx3 { NVTX3_INLINE_IF_REQUESTED namespace NVTX3_VERSION_NAMESPACE @@ -606,20 +642,35 @@ NVTX3_INLINE_IF_REQUESTED namespace NVTX3_VERSION_NAMESPACE namespace detail { -/** - * @brief Verifies if a type `T` contains a member `T::name` of type `const - * char*` or `const wchar_t*`. - * - * @tparam T The type to verify - * @return True if `T` contains a member `T::name` of type `const char*` or - * `const wchar_t*`. - */ +template +struct always_false : std::false_type {}; + +template +struct has_name : std::false_type {}; template -constexpr auto has_name_member() noexcept -> decltype(T::name, bool()) -{ - return (std::is_same::type>::value || - std::is_same::type>::value); -} +struct has_name : std::true_type {}; + +template +struct has_id : std::false_type {}; +template +struct has_id : std::true_type {}; + +template +struct has_message : std::false_type {}; +template +struct has_message : std::true_type {}; + +template +struct is_c_string : std::false_type {}; +template +struct is_c_string::value || + std::is_convertible::value +>::type> : std::true_type {}; + +template +using is_uint32 = std::is_same::type, uint32_t>; + } // namespace detail /** @@ -634,7 +685,7 @@ constexpr auto has_name_member() noexcept -> decltype(T::name, bool()) * `domain`s are expected to be long-lived and unique to a library or * application. As such, it is assumed a domain's name is known at compile * time. Therefore, all NVTX constructs that can be associated with a domain - * require the domain to be specified via a *type* `DomainName` passed as an + * require the domain to be specified via a *type* `D` passed as an * explicit template parameter. * * The type `domain::global` may be used to indicate that the global NVTX @@ -642,109 +693,46 @@ constexpr auto has_name_member() noexcept -> decltype(T::name, bool()) * * None of the C++ NVTX constructs require the user to manually construct a * `domain` object. Instead, if a custom domain is desired, the user is - * expected to define a type `DomainName` that contains a member - * `DomainName::name` which resolves to either a `char const*` or `wchar_t - * const*`. The value of `DomainName::name` is used to name and uniquely + * expected to define a type `D` that contains a member + * `D::name` which resolves to either a `char const*` or `wchar_t + * const*`. The value of `D::name` is used to name and uniquely * identify the custom domain. * * Upon the first use of an NVTX construct associated with the type - * `DomainName`, the "construct on first use" pattern is used to construct a + * `D`, the "construct on first use" pattern is used to construct a * function local static `domain` object. All future NVTX constructs - * associated with `DomainType` will use a reference to the previously + * associated with `D` will use a reference to the previously * constructed `domain` object. See `domain::get`. * * Example: - * ``` + * \code{.cpp} * // The type `my_domain` defines a `name` member used to name and identify - * the - * // `domain` object identified by `my_domain`. + * // the `domain` object identified by `my_domain`. * struct my_domain{ static constexpr char const* name{"my_domain"}; }; * * // The NVTX range `r` will be grouped with all other NVTX constructs * // associated with `my_domain`. - * nvtx3::domain_thread_range r{}; + * nvtx3::scoped_range_in r{}; * - * // An alias can be created for a `domain_thread_range` in the custom domain - * using my_thread_range = nvtx3::domain_thread_range; - * my_thread_range my_range{}; + * // An alias can be created for a `scoped_range_in` in the custom domain + * using my_scoped_range = nvtx3::scoped_range_in; + * my_scoped_range my_range{}; * * // `domain::global` indicates that the global NVTX domain is used - * nvtx3::domain_thread_range r2{}; + * nvtx3::scoped_range_in r2{}; * - * // For convenience, `nvtx3::thread_range` is an alias for a range in the + * // For convenience, `nvtx3::scoped_range` is an alias for a range in the * // global domain - * nvtx3::thread_range r3{}; - * ``` + * nvtx3::scoped_range r3{}; + * \endcode */ class domain { public: domain(domain const&) = delete; domain& operator=(domain const&) = delete; - domain(domain&&) = delete; + domain(domain&&) = delete; domain& operator=(domain&&) = delete; - /** - * @brief Returns reference to an instance of a function local static - * `domain` object. - * - * Uses the "construct on first use" idiom to safely ensure the `domain` - * object is initialized exactly once upon first invocation of - * `domain::get()`. All following invocations will return a - * reference to the previously constructed `domain` object. See - * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use - * - * None of the constructs in this header require the user to directly invoke - * `domain::get`. It is automatically invoked when constructing objects like - * a `domain_thread_range` or `category`. Advanced users may wish to use - * `domain::get` for the convenience of the "construct on first use" idiom - * when using domains with their own use of the NVTX C API. - * - * This function is threadsafe as of C++11. If two or more threads call - * `domain::get` concurrently, exactly one of them is guaranteed - * to construct the `domain` object and the other(s) will receive a - * reference to the object after it is fully constructed. - * - * The domain's name is specified via the type `DomainName` pass as an - * explicit template parameter. `DomainName` is required to contain a - * member `DomainName::name` that resolves to either a `char const*` or - * `wchar_t const*`. The value of `DomainName::name` is used to name and - * uniquely identify the `domain`. - * - * Example: - * ``` - * // The type `my_domain` defines a `name` member used to name and identify - * // the `domain` object identified by `my_domain`. - * struct my_domain{ static constexpr char const* name{"my domain"}; }; - * - * auto D = domain::get(); // First invocation constructs a - * // `domain` with the name "my domain" - * - * auto D1 = domain::get(); // Simply returns reference to - * // previously constructed `domain`. - * ``` - * - * @tparam DomainName Type that contains a `DomainName::name` member used to - * name the `domain` object. - * @return Reference to the `domain` corresponding to the type `DomainName`. - */ - template - static domain const& get() - { - static_assert(detail::has_name_member(), - "Type used to identify a domain must contain a name member of" - "type const char* or const wchar_t*"); - static domain const d{DomainName::name}; - return d; - } - - /** - * @brief Conversion operator to `nvtxDomainHandle_t`. - * - * Allows transparently passing a domain object into an API expecting a - * native `nvtxDomainHandle_t` object. - */ - operator nvtxDomainHandle_t() const noexcept { return _domain; } - /** * @brief Tag type for the "global" NVTX domain. * @@ -759,6 +747,113 @@ class domain { struct global { }; +#if NVTX3_USE_CHECKED_OVERLOADS_FOR_GET + /** + * @brief Returns reference to an instance of a function local static + * `domain` object. + * + * Uses the "construct on first use" idiom to safely ensure the `domain` + * object is initialized exactly once upon first invocation of + * `domain::get()`. All following invocations will return a + * reference to the previously constructed `domain` object. See + * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use + * + * None of the constructs in this header require the user to directly invoke + * `domain::get`. It is automatically invoked when constructing objects like + * a `scoped_range_in` or `category`. Advanced users may wish to use + * `domain::get` for the convenience of the "construct on first use" idiom + * when using domains with their own use of the NVTX C API. + * + * This function is threadsafe as of C++11. If two or more threads call + * `domain::get` concurrently, exactly one of them is guaranteed + * to construct the `domain` object and the other(s) will receive a + * reference to the object after it is fully constructed. + * + * The domain's name is specified via the type `D` pass as an + * explicit template parameter. `D` is required to contain a + * member `D::name` that resolves to either a `char const*` or + * `wchar_t const*`. The value of `D::name` is used to name and + * uniquely identify the `domain`. + * + * Example: + * \code{.cpp} + * // The type `my_domain` defines a `name` member used to name and identify + * // the `domain` object identified by `my_domain`. + * struct my_domain{ static constexpr char const* name{"my domain"}; }; + * + * auto& D1 = domain::get(); // First invocation constructs a + * // `domain` with the name "my domain" + * + * auto& D2 = domain::get(); // Quickly returns reference to + * // previously constructed `domain`. + * \endcode + * + * @tparam D Type that contains a `D::name` member used to + * name the `domain` object. + * @return Reference to the `domain` corresponding to the type `D`. + */ + template ::value + , int>::type = 0> + static domain const& get() noexcept + { + static domain const d(D::name); + return d; + } + + /** + * @brief Overload of `domain::get` to provide a clear compile error when + * `D` has a `name` member that is not directly convertible to either + * `char const*` or `wchar_t const*`. + */ + template ::value + , int>::type = 0> + static domain const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::always_false::value, + "Type used to identify an NVTX domain must contain a static constexpr member " + "called 'name' of type const char* or const wchar_t* -- 'name' member is not " + "convertible to either of those types"); + static domain const unused; + return unused; // Function must compile for static_assert to be triggered + } + + /** + * @brief Overload of `domain::get` to provide a clear compile error when + * `D` does not have a `name` member. + */ + template ::value + , int>::type = 0> + static domain const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::always_false::value, + "Type used to identify an NVTX domain must contain a static constexpr member " + "called 'name' of type const char* or const wchar_t* -- 'name' member is missing"); + static domain const unused; + return unused; // Function must compile for static_assert to be triggered + } +#else + template + static domain const& get() noexcept + { + static domain const d(D::name); + return d; + } +#endif + + /** + * @brief Conversion operator to `nvtxDomainHandle_t`. + * + * Allows transparently passing a domain object into an API expecting a + * native `nvtxDomainHandle_t` object. + */ + operator nvtxDomainHandle_t() const noexcept { return _domain; } + private: /** * @brief Construct a new domain with the specified `name`. @@ -808,7 +903,7 @@ class domain { * "global" NVTX domain. * */ - domain() = default; + domain() noexcept {} /** * @brief Intentionally avoid calling nvtxDomainDestroy on the `domain` object. @@ -844,15 +939,15 @@ class domain { * */ template <> -inline domain const& domain::get() +inline domain const& domain::get() noexcept { static domain const d{}; return d; } /** - * @brief Indicates the values of the red, green, blue color channels for - * a rgb color code. + * @brief Indicates the values of the red, green, and blue color channels for + * an RGB color to use as an event attribute (assumes no transparency). * */ struct rgb { @@ -869,19 +964,22 @@ struct rgb { * @param green_ Value of the green channel * @param blue_ Value of the blue channel */ - constexpr rgb(component_type red_, component_type green_, component_type blue_) noexcept + constexpr rgb( + component_type red_, + component_type green_, + component_type blue_) noexcept : red{red_}, green{green_}, blue{blue_} { } - component_type const red{}; ///< Red channel value - component_type const green{}; ///< Green channel value - component_type const blue{}; ///< Blue channel value + component_type red{}; ///< Red channel value + component_type green{}; ///< Green channel value + component_type blue{}; ///< Blue channel value }; /** * @brief Indicates the value of the alpha, red, green, and blue color - * channels for an argb color code. + * channels for an ARGB color to use as an event attribute. * */ struct argb final : rgb { @@ -897,15 +995,16 @@ struct argb final : rgb { * @param blue_ Value of the blue channel * */ - constexpr argb(component_type alpha_, - component_type red_, - component_type green_, - component_type blue_) noexcept + constexpr argb( + component_type alpha_, + component_type red_, + component_type green_, + component_type blue_) noexcept : rgb{red_, green_, blue_}, alpha{alpha_} { } - component_type const alpha{}; ///< Alpha channel value + component_type alpha{}; ///< Alpha channel value }; /** @@ -947,8 +1046,8 @@ class color { * * @param argb The alpha, red, green, blue components of the desired `color` */ - constexpr color(argb argb) noexcept - : color{from_bytes_msb_to_lsb(argb.alpha, argb.red, argb.green, argb.blue)} + constexpr color(argb argb_) noexcept + : color{from_bytes_msb_to_lsb(argb_.alpha, argb_.red, argb_.green, argb_.blue)} { } @@ -960,8 +1059,8 @@ class color { * * @param rgb The red, green, blue components of the desired `color` */ - constexpr color(rgb rgb) noexcept - : color{from_bytes_msb_to_lsb(0xFF, rgb.red, rgb.green, rgb.blue)} + constexpr color(rgb rgb_) noexcept + : color{from_bytes_msb_to_lsb(0xFF, rgb_.red, rgb_.green, rgb_.blue)} { } @@ -977,11 +1076,11 @@ class color { */ constexpr nvtxColorType_t get_type() const noexcept { return _type; } - color() = delete; - ~color() = default; + color() = delete; + ~color() = default; color(color const&) = default; color& operator=(color const&) = default; - color(color&&) = default; + color(color&&) = default; color& operator=(color&&) = default; private: @@ -990,16 +1089,17 @@ class color { * most to least significant byte order. * */ - constexpr static value_type from_bytes_msb_to_lsb(uint8_t byte3, - uint8_t byte2, - uint8_t byte1, - uint8_t byte0) noexcept + constexpr static value_type from_bytes_msb_to_lsb( + uint8_t byte3, + uint8_t byte2, + uint8_t byte1, + uint8_t byte0) noexcept { return uint32_t{byte3} << 24 | uint32_t{byte2} << 16 | uint32_t{byte1} << 8 | uint32_t{byte0}; } - value_type const _value{}; ///< color's argb color code - nvtxColorType_t const _type{NVTX_COLOR_ARGB}; ///< NVTX color type code + value_type _value{}; ///< color's argb color code + nvtxColorType_t _type{NVTX_COLOR_ARGB}; ///< NVTX color type code }; /** @@ -1014,10 +1114,10 @@ class color { * nvtx3::category cat1{1}; * * // Range `r1` belongs to the category identified by the value `1`. - * nvtx3::thread_range r1{cat1}; + * nvtx3::scoped_range r1{cat1}; * * // Range `r2` belongs to the same category as `r1` - * nvtx3::thread_range r2{nvtx3::category{1}}; + * nvtx3::scoped_range r2{nvtx3::category{1}}; * \endcode * * To associate a name string with a category id, see `named_category`. @@ -1033,7 +1133,7 @@ class category { * * The `category` will be unnamed and identified only by its `id` value. * - * All `category` objects sharing the same `id` are equivalent. + * All `category`s in a domain sharing the same `id` are equivalent. * * @param[in] id The `category`'s identifying value */ @@ -1045,15 +1145,15 @@ class category { */ constexpr id_type get_id() const noexcept { return id_; } - category() = delete; - ~category() = default; + category() = delete; + ~category() = default; category(category const&) = default; category& operator=(category const&) = default; - category(category&&) = default; + category(category&&) = default; category& operator=(category&&) = default; private: - id_type const id_{}; ///< category's unique identifier + id_type id_{}; ///< category's unique identifier }; /** @@ -1075,45 +1175,46 @@ class category { * * Example: * \code{.cpp} - * // Explicitly constructed, static `named_category` + * // Explicitly constructed, static `named_category` in global domain: * static nvtx3::named_category static_category{42, "my category"}; * * // Range `r` associated with category id `42` - * nvtx3::thread_range r{static_category}; + * nvtx3::scoped_range r{static_category}; * * // OR use construct on first use: * * // Define a type with `name` and `id` members - * struct my_category{ + * struct my_category { * static constexpr char const* name{"my category"}; // category name - * static constexpr category::id_type id{42}; // category id + * static constexpr uint32_t id{42}; // category id * }; * * // Use construct on first use to name the category id `42` * // with name "my category" - * auto my_category = named_category::get(); + * auto& cat = named_category_in::get(); * * // Range `r` associated with category id `42` - * nvtx3::thread_range r{my_category}; + * nvtx3::scoped_range r{cat}; * \endcode * - * `named_category`'s association of a name to a category id is local to the - * domain specified by the type `D`. An id may have a different name in + * `named_category_in`'s association of a name to a category id is local to + * the domain specified by the type `D`. An id may have a different name in * another domain. * * @tparam D Type containing `name` member used to identify the `domain` to - * which the `named_category` belongs. Else, `domain::global` to indicate + * which the `named_category_in` belongs. Else, `domain::global` to indicate * that the global NVTX domain should be used. */ template -class named_category final : public category { +class named_category_in final : public category { public: +#if NVTX3_USE_CHECKED_OVERLOADS_FOR_GET /** - * @brief Returns a global instance of a `named_category` as a + * @brief Returns a global instance of a `named_category_in` as a * function-local static. * - * Creates a `named_category` with name and id specified by the contents of - * a type `C`. `C::name` determines the name and `C::id` determines the + * Creates a `named_category_in` with name and id specified by the contents + * of a type `C`. `C::name` determines the name and `C::id` determines the * category id. * * This function is useful for constructing a named `category` exactly once @@ -1122,36 +1223,97 @@ class named_category final : public category { * Example: * \code{.cpp} * // Define a type with `name` and `id` members - * struct my_category{ + * struct my_category { * static constexpr char const* name{"my category"}; // category name * static constexpr uint32_t id{42}; // category id * }; * * // Use construct on first use to name the category id `42` * // with name "my category" - * auto cat = named_category::get(); + * auto& cat = named_category_in::get(); * * // Range `r` associated with category id `42` - * nvtx3::thread_range r{cat}; + * nvtx3::scoped_range r{cat}; * \endcode * * Uses the "construct on first use" idiom to safely ensure the `category` * object is initialized exactly once. See * https://isocpp.org/wiki/faq/ctors#static-init-order-on-first-use * - * @tparam C Type containing a member `C::name` that resolves to either a + * @tparam C Type containing a member `C::name` that resolves to either a * `char const*` or `wchar_t const*` and `C::id`. */ - template - static named_category const& get() noexcept + template ::value && + detail::is_uint32::value + , int>::type = 0> + static named_category_in const& get() noexcept { - static_assert(detail::has_name_member(), - "Type used to name a category must contain a name member."); - static named_category const category{C::id, C::name}; - return category; + static named_category_in const cat(C::id, C::name); + return cat; } + /** - * @brief Construct a `category` with the specified `id` and `name`. + * @brief Overload of `named_category_in::get` to provide a clear compile error + * when `C` has the required `name` and `id` members, but they are not the + * required types. `name` must be directly convertible to `char const*` or + * `wchar_t const*`, and `id` must be `uint32_t`. + */ + template ::value || + !detail::is_uint32::value + , int>::type = 0> + static named_category_in const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::is_c_string::value, + "Type used to name an NVTX category must contain a static constexpr member " + "called 'name' of type const char* or const wchar_t* -- 'name' member is not " + "convertible to either of those types"); + NVTX3_STATIC_ASSERT(detail::is_uint32::value, + "Type used to name an NVTX category must contain a static constexpr member " + "called 'id' of type uint32_t -- 'id' member is the wrong type"); + static named_category_in const unused; + return unused; // Function must compile for static_assert to be triggered + } + + /** + * @brief Overload of `named_category_in::get` to provide a clear compile error + * when `C` does not have the required `name` and `id` members. + */ + template ::value || + !detail::has_id::value + , int>::type = 0> + static named_category_in const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::has_name::value, + "Type used to name an NVTX category must contain a static constexpr member " + "called 'name' of type const char* or const wchar_t* -- 'name' member is missing"); + NVTX3_STATIC_ASSERT(detail::has_id::value, + "Type used to name an NVTX category must contain a static constexpr member " + "called 'id' of type uint32_t -- 'id' member is missing"); + static named_category_in const unused; + return unused; // Function must compile for static_assert to be triggered + } +#else + template + static named_category_in const& get() noexcept + { + static named_category_in const cat(C::id, C::name); + return cat; + } +#endif + + private: + // Default constructor is only used internally for static_assert(false) cases. + named_category_in() noexcept : category{0} {} + + public: + /** + * @brief Construct a `named_category_in` with the specified `id` and `name`. * * The name `name` will be registered with `id`. * @@ -1160,7 +1322,7 @@ class named_category final : public category { * @param[in] id The category id to name * @param[in] name The name to associated with `id` */ - named_category(id_type id, char const* name) noexcept : category{id} + named_category_in(id_type id, char const* name) noexcept : category{id} { #ifndef NVTX_DISABLE nvtxDomainNameCategoryA(domain::get(), get_id(), name); @@ -1171,7 +1333,7 @@ class named_category final : public category { }; /** - * @brief Construct a `category` with the specified `id` and `name`. + * @brief Construct a `named_category_in` with the specified `id` and `name`. * * The name `name` will be registered with `id`. * @@ -1180,7 +1342,7 @@ class named_category final : public category { * @param[in] id The category id to name * @param[in] name The name to associated with `id` */ - named_category(id_type id, wchar_t const* name) noexcept : category{id} + named_category_in(id_type id, wchar_t const* name) noexcept : category{id} { #ifndef NVTX_DISABLE nvtxDomainNameCategoryW(domain::get(), get_id(), name); @@ -1191,6 +1353,12 @@ class named_category final : public category { }; }; +/** + * @brief Alias for a `named_category_in` in the global NVTX domain. + * + */ +using named_category = named_category_in; + /** * @brief A message registered with NVTX. * @@ -1205,16 +1373,16 @@ class named_category final : public category { * * A particular message should only be registered once and the handle * reused throughout the rest of the application. This can be done by either - * explicitly creating static `registered_string` objects, or using the - * `registered_string::get` construct on first use helper (recommended). + * explicitly creating static `registered_string_in` objects, or using the + * `registered_string_in::get` construct on first use helper (recommended). * * Example: * \code{.cpp} - * // Explicitly constructed, static `registered_string` - * static registered_string static_message{"message"}; + * // Explicitly constructed, static `registered_string` in my_domain: + * static registered_string_in static_message{"message"}; * * // "message" is associated with the range `r` - * nvtx3::thread_range r{static_message}; + * nvtx3::scoped_range r{static_message}; * * // Or use construct on first use: * @@ -1224,30 +1392,31 @@ class named_category final : public category { * * // Uses construct on first use to register the contents of * // `my_message::message` - * auto msg = registered_string::get(); + * auto& msg = registered_string_in::get(); * * // "my message" is associated with the range `r` - * nvtx3::thread_range r{msg}; + * nvtx3::scoped_range r{msg}; * \endcode * - * `registered_string`s are local to a particular domain specified via + * `registered_string_in`s are local to a particular domain specified via * the type `D`. * * @tparam D Type containing `name` member used to identify the `domain` to - * which the `registered_string` belongs. Else, `domain::global` to indicate + * which the `registered_string_in` belongs. Else, `domain::global` to indicate * that the global NVTX domain should be used. */ template -class registered_string { +class registered_string_in { public: +#if NVTX3_USE_CHECKED_OVERLOADS_FOR_GET /** - * @brief Returns a global instance of a `registered_string` as a function + * @brief Returns a global instance of a `registered_string_in` as a function * local static. * * Provides a convenient way to register a message with NVTX without having * to explicitly register the message. * - * Upon first invocation, constructs a `registered_string` whose contents + * Upon first invocation, constructs a `registered_string_in` whose contents * are specified by `message::message`. * * All future invocations will return a reference to the object constructed @@ -1262,26 +1431,74 @@ class registered_string { * * // Uses construct on first use to register the contents of * // `my_message::message` - * auto msg = registered_string::get(); + * auto& msg = registered_string_in::get(); * * // "my message" is associated with the range `r` - * nvtx3::thread_range r{msg}; + * nvtx3::scoped_range r{msg}; * \endcode * * @tparam M Type required to contain a member `M::message` that * resolves to either a `char const*` or `wchar_t const*` used as the * registered string's contents. - * @return Reference to a `registered_string` associated with the type `M`. + * @return Reference to a `registered_string_in` associated with the type `M`. */ - template - static registered_string const& get() noexcept + template ::value + , int>::type = 0> + static registered_string_in const& get() noexcept { - static registered_string const registered_string{M::message}; - return registered_string; + static registered_string_in const regstr(M::message); + return regstr; } /** - * @brief Constructs a `registered_string` from the specified `msg` string. + * @brief Overload of `registered_string_in::get` to provide a clear compile error + * when `M` has a `message` member that is not directly convertible to either + * `char const*` or `wchar_t const*`. + */ + template ::value + , int>::type = 0> + static registered_string_in const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::always_false::value, + "Type used to register an NVTX string must contain a static constexpr member " + "called 'message' of type const char* or const wchar_t* -- 'message' member is " + "not convertible to either of those types"); + static registered_string_in const unused; + return unused; // Function must compile for static_assert to be triggered + } + + /** + * @brief Overload of `registered_string_in::get` to provide a clear compile error when + * `M` does not have a `message` member. + */ + template ::value + , int>::type = 0> + static registered_string_in const& get() noexcept + { + NVTX3_STATIC_ASSERT(detail::always_false::value, + "Type used to register an NVTX string must contain a static constexpr member " + "called 'message' of type const char* or const wchar_t* -- 'message' member " + "is missing"); + static registered_string_in const unused; + return unused; // Function must compile for static_assert to be triggered + } +#else + template + static registered_string_in const& get() noexcept + { + static registered_string_in const regstr(M::message); + return regstr; + } +#endif + + /** + * @brief Constructs a `registered_string_in` from the specified `msg` string. * * Registers `msg` with NVTX and associates a handle with the registered * message. @@ -1291,13 +1508,13 @@ class registered_string { * * @param msg The contents of the message */ - explicit registered_string(char const* msg) noexcept + explicit registered_string_in(char const* msg) noexcept : handle_{nvtxDomainRegisterStringA(domain::get(), msg)} { } /** - * @brief Constructs a `registered_string` from the specified `msg` string. + * @brief Constructs a `registered_string_in` from the specified `msg` string. * * Registers `msg` with NVTX and associates a handle with the registered * message. @@ -1307,10 +1524,11 @@ class registered_string { * * @param msg The contents of the message */ - explicit registered_string(std::string const& msg) noexcept : registered_string{msg.c_str()} {} + explicit registered_string_in(std::string const& msg) noexcept + : registered_string_in{msg.c_str()} {} /** - * @brief Constructs a `registered_string` from the specified `msg` string. + * @brief Constructs a `registered_string_in` from the specified `msg` string. * * Registers `msg` with NVTX and associates a handle with the registered * message. @@ -1320,13 +1538,13 @@ class registered_string { * * @param msg The contents of the message */ - explicit registered_string(wchar_t const* msg) noexcept + explicit registered_string_in(wchar_t const* msg) noexcept : handle_{nvtxDomainRegisterStringW(domain::get(), msg)} { } /** - * @brief Constructs a `registered_string` from the specified `msg` string. + * @brief Constructs a `registered_string_in` from the specified `msg` string. * * Registers `msg` with NVTX and associates a handle with the registered * message. @@ -1336,7 +1554,8 @@ class registered_string { * * @param msg The contents of the message */ - explicit registered_string(std::wstring const& msg) noexcept : registered_string{msg.c_str()} {} + explicit registered_string_in(std::wstring const& msg) noexcept + : registered_string_in{msg.c_str()} {} /** * @brief Returns the registered string's handle @@ -1344,18 +1563,27 @@ class registered_string { */ nvtxStringHandle_t get_handle() const noexcept { return handle_; } - registered_string() = delete; - ~registered_string() = default; - registered_string(registered_string const&) = default; - registered_string& operator=(registered_string const&) = default; - registered_string(registered_string&&) = default; - registered_string& operator=(registered_string&&) = default; +private: + // Default constructor is only used internally for static_assert(false) cases. + registered_string_in() noexcept {}; +public: + ~registered_string_in() = default; + registered_string_in(registered_string_in const&) = default; + registered_string_in& operator=(registered_string_in const&) = default; + registered_string_in(registered_string_in&&) = default; + registered_string_in& operator=(registered_string_in&&) = default; private: - nvtxStringHandle_t const handle_{}; ///< The handle returned from - ///< registering the message with NVTX + nvtxStringHandle_t handle_{}; ///< The handle returned from + ///< registering the message with NVTX }; +/** + * @brief Alias for a `registered_string_in` in the global NVTX domain. + * + */ +using registered_string = registered_string_in; + /** * @brief Allows associating a message string with an NVTX event via * its `EventAttribute`s. @@ -1374,7 +1602,7 @@ class registered_string { * nvtx3::event_attributes attr0{nvtx3::message{"message 0"}}; * * // `range0` contains message "message 0" - * nvtx3::thread_range range0{attr0}; + * nvtx3::scoped_range range0{attr0}; * * // `std::string` and string literals are implicitly assumed to be * // the contents of an `nvtx3::message` @@ -1382,15 +1610,15 @@ class registered_string { * nvtx3::event_attributes attr1{"message 1"}; * * // `range1` contains message "message 1" - * nvtx3::thread_range range1{attr1}; + * nvtx3::scoped_range range1{attr1}; * * // `range2` contains message "message 2" - * nvtx3::thread_range range2{nvtx3::Mesage{"message 2"}}; + * nvtx3::scoped_range range2{nvtx3::Mesage{"message 2"}}; * * // `std::string` and string literals are implicitly assumed to be * // the contents of an `nvtx3::message` * // `range3` contains message "message 3" - * nvtx3::thread_range range3{"message 3"}; + * nvtx3::scoped_range range3{"message 3"}; * \endcode */ class message { @@ -1402,7 +1630,7 @@ class message { * * @param msg The contents of the message */ - NVTX3_RELAXED_CONSTEXPR message(char const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_ASCII} + NVTX3_CONSTEXPR_IF_CPP14 message(char const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_ASCII} { value_.ascii = msg; } @@ -1429,7 +1657,7 @@ class message { * * @param msg The contents of the message */ - NVTX3_RELAXED_CONSTEXPR message(wchar_t const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_UNICODE} + NVTX3_CONSTEXPR_IF_CPP14 message(wchar_t const* msg) noexcept : type_{NVTX_MESSAGE_TYPE_UNICODE} { value_.unicode = msg; } @@ -1452,35 +1680,59 @@ class message { message(std::wstring&&) = delete; /** - * @brief Construct a `message` from a `registered_string`. + * @brief Construct a `message` from a `registered_string_in`. * * @tparam D Type containing `name` member used to identify the `domain` - * to which the `registered_string` belongs. Else, `domain::global` to + * to which the `registered_string_in` belongs. Else, `domain::global` to * indicate that the global NVTX domain should be used. * @param msg The message that has already been registered with NVTX. */ template - NVTX3_RELAXED_CONSTEXPR message(registered_string const& msg) noexcept + NVTX3_CONSTEXPR_IF_CPP14 message(registered_string_in const& msg) noexcept : type_{NVTX_MESSAGE_TYPE_REGISTERED} { value_.registered = msg.get_handle(); } + /** + * @brief Construct a `message` from NVTX C API type and value. + * + * @param type nvtxMessageType_t enum value indicating type of the payload + * @param value nvtxMessageValue_t union containing message + */ + constexpr message( + nvtxMessageType_t const& type, + nvtxMessageValue_t const& value) noexcept + : type_{type}, value_(value) + { + } + + /** + * @brief Construct a `message` from NVTX C API registered string handle. + * + * @param handle nvtxStringHandle_t value of registered string handle + */ + NVTX3_CONSTEXPR_IF_CPP14 message(nvtxStringHandle_t handle) noexcept + : type_{NVTX_MESSAGE_TYPE_REGISTERED} + { + value_.registered = handle; + } + /** * @brief Return the union holding the value of the message. * */ - NVTX3_RELAXED_CONSTEXPR value_type get_value() const noexcept { return value_; } + constexpr value_type get_value() const noexcept { return value_; } /** * @brief Return the type information about the value the union holds. * */ - NVTX3_RELAXED_CONSTEXPR nvtxMessageType_t get_type() const noexcept { return type_; } + constexpr nvtxMessageType_t get_type() const noexcept { return type_; } private: - nvtxMessageType_t const type_{}; ///< message type - nvtxMessageValue_t value_{}; ///< message contents + nvtxMessageType_t type_{}; ///< message type + nvtxMessageValue_t value_{}; ///< message contents }; /** @@ -1488,17 +1740,16 @@ class message { * its `event_attributes`. * * Example: - * ``` - * nvtx3:: event_attributes attr{nvtx3::payload{42}}; // Constructs a payload - * from - * // the `int32_t` value 42 + * \code{.cpp} + * // Constructs a payload from the int32_t value 42 + * nvtx3:: event_attributes attr{nvtx3::payload{42}}; * * // `range0` will have an int32_t payload of 42 - * nvtx3::thread_range range0{attr}; + * nvtx3::scoped_range range0{attr}; * * // range1 has double payload of 3.14 - * nvtx3::thread_range range1{ nvtx3::payload{3.14} }; - * ``` + * nvtx3::scoped_range range1{nvtx3::payload{3.14}}; + * \endcode */ class payload { public: @@ -1509,7 +1760,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(int64_t value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(int64_t value) noexcept : type_{NVTX_PAYLOAD_TYPE_INT64}, value_{} { value_.llValue = value; @@ -1520,7 +1771,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(int32_t value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(int32_t value) noexcept : type_{NVTX_PAYLOAD_TYPE_INT32}, value_{} { value_.iValue = value; @@ -1531,7 +1782,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(uint64_t value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(uint64_t value) noexcept : type_{NVTX_PAYLOAD_TYPE_UNSIGNED_INT64}, value_{} { value_.ullValue = value; @@ -1542,7 +1793,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(uint32_t value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(uint32_t value) noexcept : type_{NVTX_PAYLOAD_TYPE_UNSIGNED_INT32}, value_{} { value_.uiValue = value; @@ -1554,7 +1805,7 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(float value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(float value) noexcept : type_{NVTX_PAYLOAD_TYPE_FLOAT}, value_{} { value_.fValue = value; @@ -1566,27 +1817,40 @@ class payload { * * @param value Value to use as contents of the payload */ - NVTX3_RELAXED_CONSTEXPR explicit payload(double value) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit payload(double value) noexcept : type_{NVTX_PAYLOAD_TYPE_DOUBLE}, value_{} { value_.dValue = value; } + /** + * @brief Construct a `payload` from NVTX C API type and value. + * + * @param type nvtxPayloadType_t enum value indicating type of the payload + * @param value nvtxEventAttributes_t::payload_t union containing payload + */ + constexpr payload( + nvtxPayloadType_t const& type, + value_type const& value) noexcept + : type_{type}, value_(value) + { + } + /** * @brief Return the union holding the value of the payload * */ - NVTX3_RELAXED_CONSTEXPR value_type get_value() const noexcept { return value_; } + constexpr value_type get_value() const noexcept { return value_; } /** * @brief Return the information about the type the union holds. * */ - NVTX3_RELAXED_CONSTEXPR nvtxPayloadType_t get_type() const noexcept { return type_; } + constexpr nvtxPayloadType_t get_type() const noexcept { return type_; } private: - nvtxPayloadType_t const type_; ///< Type of the payload value - value_type value_; ///< Union holding the payload value + nvtxPayloadType_t type_; ///< Type of the payload value + value_type value_; ///< Union holding the payload value }; /** @@ -1611,42 +1875,39 @@ class payload { * * Example: * \code{.cpp} - * event_attributes attr{}; // No arguments, use defaults for all attributes + * // Set message, same as using nvtx3::message{"message"} + * event_attributes attr{"message"}; * - * event_attributes attr{"message"}; // Custom message, rest defaulted - * - * // Custom color & message + * // Set message and color * event_attributes attr{"message", nvtx3::rgb{127, 255, 0}}; * - * /// Custom color & message, can use any order of arguments - * event_attributes attr{nvtx3::rgb{127, 255, 0}, "message"}; + * // Set message, color, payload, category + * event_attributes attr{"message", + * nvtx3::rgb{127, 255, 0}, + * nvtx3::payload{42}, + * nvtx3::category{1}}; * - * - * // Custom color, message, payload, category - * event_attributes attr{nvtx3::rgb{127, 255, 0}, - * "message", - * nvtx3::payload{42}, - * nvtx3::category{1}}; - * - * // Custom color, message, payload, category, can use any order of arguments + * // Same as above -- can use any order of arguments * event_attributes attr{nvtx3::payload{42}, - * nvtx3::category{1}, - * "message", - * nvtx3::rgb{127, 255, 0}}; + * nvtx3::category{1}, + * "message", + * nvtx3::rgb{127, 255, 0}}; * * // Multiple arguments of the same type are allowed, but only the first is - * // used. All others are ignored - * event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; // payload - * is 42 + * // used -- in this example, payload is set to 42: + * event_attributes attr{ nvtx3::payload{42}, nvtx3::payload{7} }; * * // Range `r` will be customized according the attributes in `attr` - * nvtx3::thread_range r{attr}; + * nvtx3::scoped_range r{attr}; * - * // For convenience, the arguments that can be passed to the - * `event_attributes` - * // constructor may be passed to the `domain_thread_range` contructor where - * // they will be forwarded to the `EventAttribute`s constructor - * nvtx3::thread_range r{nvtx3::payload{42}, nvtx3::category{1}, "message"}; + * // For convenience, `event_attributes` constructor arguments may be passed + * // to the `scoped_range_in` contructor -- they are forwarded to the + * // `event_attributes` constructor + * nvtx3::scoped_range r{nvtx3::payload{42}, nvtx3::category{1}, "message"}; + * + * // Using the nvtx3 namespace in a local scope makes the syntax more succinct: + * using namespace nvtx3; + * scoped_range r{payload{42}, category{1}, "message"}; * \endcode * */ @@ -1682,7 +1943,7 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(category const& c, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(category const& c, Args const&... args) noexcept : event_attributes(args...) { attributes_.category = c.get_id(); @@ -1696,7 +1957,7 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(color const& c, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(color const& c, Args const&... args) noexcept : event_attributes(args...) { attributes_.color = c.get_value(); @@ -1711,7 +1972,7 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(payload const& p, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(payload const& p, Args const&... args) noexcept : event_attributes(args...) { attributes_.payload = p.get_value(); @@ -1726,14 +1987,14 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(message const& m, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(message const& m, Args const&... args) noexcept : event_attributes(args...) { attributes_.message = m.get_value(); attributes_.messageType = m.get_type(); } - /** + /** * @brief Variadic constructor where the first argument is a binary payload. * * Sets the value of the `EventAttribute`s message based on `m` and forwards @@ -1741,7 +2002,7 @@ class event_attributes { * */ template - NVTX3_RELAXED_CONSTEXPR explicit event_attributes(nvtxPayloadData_t const* bpl, Args const&... args) noexcept + NVTX3_CONSTEXPR_IF_CPP14 explicit event_attributes(nvtxPayloadData_t const* bpl, Args const&... args) noexcept : event_attributes(args...) { attributes_.payloadType = NVTX_PAYLOAD_TYPE_BINARY; @@ -1749,10 +2010,10 @@ class event_attributes { attributes_.payload.ullValue = NVTX_POINTER_AS_PAYLOAD_ULLVALUE(bpl); } - ~event_attributes() = default; + ~event_attributes() = default; event_attributes(event_attributes const&) = default; event_attributes& operator=(event_attributes const&) = default; - event_attributes(event_attributes&&) = default; + event_attributes(event_attributes&&) = default; event_attributes& operator=(event_attributes&&) = default; /** @@ -1772,16 +2033,16 @@ class event_attributes { * When constructed, begins a nested NVTX range on the calling thread in the * specified domain. Upon destruction, ends the NVTX range. * - * Behavior is undefined if a `domain_thread_range` object is + * Behavior is undefined if a `scoped_range_in` object is * created/destroyed on different threads. * - * `domain_thread_range` is neither moveable nor copyable. + * `scoped_range_in` is neither moveable nor copyable. * - * `domain_thread_range`s may be nested within other ranges. + * `scoped_range_in`s may be nested within other ranges. * * The domain of the range is specified by the template type parameter `D`. * By default, the `domain::global` is used, which scopes the range to the - * global NVTX domain. The convenience alias `thread_range` is provided for + * global NVTX domain. The convenience alias `scoped_range` is provided for * ranges scoped to the global domain. * * A custom domain can be defined by creating a type, `D`, with a static @@ -1789,48 +2050,47 @@ class event_attributes { * `D`. `D::name` must resolve to either `char const*` or `wchar_t const*` * * Example: - * ``` + * \code{.cpp} * // Define a type `my_domain` with a member `name` used to name the domain * // associated with the type `my_domain`. * struct my_domain{ - * static constexpr const char * name{"my domain"}; + * static constexpr char const* name{"my domain"}; * }; - * ``` + * \endcode * * Usage: - * ``` - * nvtx3::domain_thread_range<> r0{"range 0"}; // Range in global domain + * \code{.cpp} + * nvtx3::scoped_range_in r1{"range 1"}; // Range in my domain * - * nvtx3::thread_range r1{"range 1"}; // Alias for range in global domain + * // Three equivalent ways to make a range in the global domain: + * nvtx3::scoped_range_in r2{"range 2"}; + * nvtx3::scoped_range_in<> r3{"range 3"}; + * nvtx3::scoped_range r4{"range 4"}; * - * nvtx3::domain_thread_range r2{"range 2"}; // Range in custom - * domain + * // Create an alias to succinctly make ranges in my domain: + * using my_scoped_range = nvtx3::scoped_range_in; * - * // specify an alias to a range that uses a custom domain - * using my_thread_range = nvtx3::domain_thread_range; - * - * my_thread_range r3{"range 3"}; // Alias for range in custom domain - * ``` + * my_scoped_range r3{"range 3"}; + * \endcode */ template -class domain_thread_range { +class scoped_range_in { public: /** - * @brief Construct a `domain_thread_range` with the specified + * @brief Construct a `scoped_range_in` with the specified * `event_attributes` * * Example: - * ``` + * \code{cpp} * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}}; - * nvtx3::domain_thread_range<> range{attr}; // Creates a range with message - * contents - * // "msg" and green color - * ``` + * nvtx3::scoped_range range{attr}; // Creates a range with message contents + * // "msg" and green color + * \endcode * * @param[in] attr `event_attributes` that describes the desired attributes * of the range. */ - explicit domain_thread_range(event_attributes const& attr) noexcept + explicit scoped_range_in(event_attributes const& attr) noexcept { #ifndef NVTX_DISABLE nvtxDomainRangePushEx(domain::get(), attr.get()); @@ -1840,65 +2100,55 @@ class domain_thread_range { } /** - * @brief Constructs a `domain_thread_range` from the constructor arguments + * @brief Constructs a `scoped_range_in` from the constructor arguments * of an `event_attributes`. * - * Forwards the arguments `first, args...` to construct an + * Forwards the arguments `args...` to construct an * `event_attributes` object. The `event_attributes` object is then - * associated with the `domain_thread_range`. + * associated with the `scoped_range_in`. * * For more detail, see `event_attributes` documentation. * * Example: - * ``` + * \code{cpp} * // Creates a range with message "message" and green color - * nvtx3::domain_thread_range<> r{"message", nvtx3::rgb{127,255,0}}; - * ``` + * nvtx3::scoped_range r{"message", nvtx3::rgb{127,255,0}}; + * \endcode * - * @note To prevent making needless copies of `event_attributes` objects, - * this constructor is disabled when the first argument is an - * `event_attributes` object, instead preferring the explicit - * `domain_thread_range(event_attributes const&)` constructor. - * - * @param[in] first First argument to forward to the `event_attributes` - * constructor. - * @param[in] args Variadic parameter pack of additional arguments to - * forward. + * @param[in] args Arguments to used to construct an `event_attributes` associated with this + * range. * */ - template >::value>> - explicit domain_thread_range(First const& first, Args const&... args) noexcept - : domain_thread_range{event_attributes{first, args...}} + template + explicit scoped_range_in(Args const&... args) noexcept + : scoped_range_in{event_attributes{args...}} { } /** - * @brief Default constructor creates a `domain_thread_range` with no + * @brief Default constructor creates a `scoped_range_in` with no * message, color, payload, nor category. * */ - domain_thread_range() : domain_thread_range{event_attributes{}} {} + scoped_range_in() noexcept : scoped_range_in{event_attributes{}} {} /** * @brief Delete `operator new` to disallow heap allocated objects. * - * `domain_thread_range` must follow RAII semantics to guarantee proper push/pop semantics. + * `scoped_range_in` must follow RAII semantics to guarantee proper push/pop semantics. * */ void* operator new(std::size_t) = delete; - domain_thread_range(domain_thread_range const&) = delete; - domain_thread_range& operator=(domain_thread_range const&) = delete; - domain_thread_range(domain_thread_range&&) = delete; - domain_thread_range& operator=(domain_thread_range&&) = delete; + scoped_range_in(scoped_range_in const&) = delete; + scoped_range_in& operator=(scoped_range_in const&) = delete; + scoped_range_in(scoped_range_in&&) = delete; + scoped_range_in& operator=(scoped_range_in&&) = delete; /** - * @brief Destroy the domain_thread_range, ending the NVTX range event. + * @brief Destroy the scoped_range_in, ending the NVTX range event. */ - ~domain_thread_range() noexcept + ~scoped_range_in() noexcept { #ifndef NVTX_DISABLE nvtxDomainRangePop(domain::get()); @@ -1907,25 +2157,103 @@ class domain_thread_range { }; /** - * @brief Alias for a `domain_thread_range` in the global NVTX domain. + * @brief Alias for a `scoped_range_in` in the global NVTX domain. * */ -using thread_range = domain_thread_range<>; +using scoped_range = scoped_range_in; + +namespace detail { + +/// @cond internal +template +class optional_scoped_range_in +{ +public: + optional_scoped_range_in() = default; + + void begin(event_attributes const& attr) noexcept + { +#ifndef NVTX_DISABLE + // This class is not meant to be part of the public NVTX C++ API and should + // only be used in the `NVTX3_FUNC_RANGE_IF` and `NVTX3_FUNC_RANGE_IF_IN` + // macros. However, to prevent developers from misusing this class, make + // sure to not start multiple ranges. + if (initialized) { return; } + + nvtxDomainRangePushEx(domain::get(), attr.get()); + initialized = true; +#endif + } + + ~optional_scoped_range_in() noexcept + { +#ifndef NVTX_DISABLE + if (initialized) { nvtxDomainRangePop(domain::get()); } +#endif + } + + void* operator new(std::size_t) = delete; + optional_scoped_range_in(optional_scoped_range_in const&) = delete; + optional_scoped_range_in& operator=(optional_scoped_range_in const&) = delete; + optional_scoped_range_in(optional_scoped_range_in&&) = delete; + optional_scoped_range_in& operator=(optional_scoped_range_in&&) = delete; + +private: +#ifndef NVTX_DISABLE + bool initialized = false; +#endif +}; +/// @endcond + +} // namespace detail /** * @brief Handle used for correlating explicit range start and end events. * + * A handle is "null" if it does not correspond to any range. + * */ struct range_handle { /// Type used for the handle's value using value_type = nvtxRangeId_t; + /** * @brief Construct a `range_handle` from the given id. * */ constexpr explicit range_handle(value_type id) noexcept : _range_id{id} {} + /** + * @brief Constructs a null range handle. + * + * A null range_handle corresponds to no range. Calling `end_range` on a + * null handle is undefined behavior when a tool is active. + * + */ + constexpr range_handle() noexcept = default; + + /** + * @brief Checks whether this handle is null + * + * Provides contextual conversion to `bool`. + * + * \code{cpp} + * range_handle handle{}; + * if (handle) {...} + * \endcode + * + */ + constexpr explicit operator bool() const noexcept { return get_value() != null_range_id; }; + + /** + * @brief Implicit conversion from `nullptr` constructs a null handle. + * + * Satisfies the "NullablePointer" requirement to make `range_handle` comparable with `nullptr`. + * + */ + constexpr range_handle(std::nullptr_t) noexcept {} + /** * @brief Returns the `range_handle`'s value * @@ -1934,42 +2262,68 @@ struct range_handle { constexpr value_type get_value() const noexcept { return _range_id; } private: - value_type _range_id{}; ///< The underlying NVTX range id + /// Sentinel value for a null handle that corresponds to no range + static constexpr value_type null_range_id = nvtxRangeId_t{0}; + + value_type _range_id{null_range_id}; ///< The underlying NVTX range id }; +/** + * @brief Compares two range_handles for equality + * + * @param lhs The first range_handle to compare + * @param rhs The second range_handle to compare + */ +inline constexpr bool operator==(range_handle lhs, range_handle rhs) noexcept +{ + return lhs.get_value() == rhs.get_value(); +} + +/** + * @brief Compares two range_handles for inequality + * + * @param lhs The first range_handle to compare + * @param rhs The second range_handle to compare + */ +inline constexpr bool operator!=(range_handle lhs, range_handle rhs) noexcept { return !(lhs == rhs); } + /** * @brief Manually begin an NVTX range. * * Explicitly begins an NVTX range and returns a unique handle. To end the - * range, pass the handle to `end_range()`. + * range, pass the handle to `end_range_in()`. * - * `start_range/end_range` are the most explicit and lowest level APIs provided - * for creating ranges. Use of `nvtx3::domain_process_range` should be + * `nvtx3::start_range(...)` is equivalent to `nvtx3::start_range_in<>(...)` and + * `nvtx3::start_range_in(...)`. + * + * `start_range_in/end_range_in` are the most explicit and lowest level APIs + * provided for creating ranges. Use of `nvtx3::unique_range_in` should be * preferred unless one is unable to tie the range to the lifetime of an object. * * Example: - * ``` + * \code{.cpp} * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}}; - * nvtx3::range_handle h = nvxt3::start_range(attr); // Manually begins a range + * // Manually begin a range + * nvtx3::range_handle h = nvtx3::start_range_in(attr); * ... - * nvtx3::end_range(h); // Ends the range - * ``` + * nvtx3::end_range_in(h); // End the range + * \endcode * * @tparam D Type containing `name` member used to identify the `domain` * to which the range belongs. Else, `domain::global` to indicate that the * global NVTX domain should be used. * @param[in] attr `event_attributes` that describes the desired attributes * of the range. - * @return Unique handle to be passed to `end_range` to end the range. + * @return Unique handle to be passed to `end_range_in` to end the range. */ template -range_handle start_range(event_attributes const& attr) noexcept +inline range_handle start_range_in(event_attributes const& attr) noexcept { #ifndef NVTX_DISABLE return range_handle{nvtxDomainRangeStartEx(domain::get(), attr.get())}; #else (void)attr; - return range_handle{}; + return {}; #endif } @@ -1977,60 +2331,157 @@ range_handle start_range(event_attributes const& attr) noexcept * @brief Manually begin an NVTX range. * * Explicitly begins an NVTX range and returns a unique handle. To end the - * range, pass the handle to `end_range()`. + * range, pass the handle to `end_range_in()`. * - * Forwards the arguments `first, args...` to construct an `event_attributes` - * object. The `event_attributes` object is then associated with the range. + * `nvtx3::start_range(...)` is equivalent to `nvtx3::start_range_in<>(...)` and + * `nvtx3::start_range_in(...)`. * - * For more detail, see `event_attributes` documentation. - * - * Example: - * ``` - * nvtx3::range_handle h = nvxt3::start_range("msg", nvtx3::rgb{127,255,0}); // - * Begin range - * ... - * nvtx3::end_range(h); // Ends the range - * ``` - * - * `start_range/end_range` are the most explicit and lowest level APIs provided - * for creating ranges. Use of `nvtx3::domain_process_range` should be + * `start_range_in/end_range_in` are the most explicit and lowest level APIs + * provided for creating ranges. Use of `nvtx3::unique_range_in` should be * preferred unless one is unable to tie the range to the lifetime of an object. * - * @param first[in] First argument to pass to an `event_attributes` - * @param args[in] Variadiac parameter pack of the rest of the arguments for an - * `event_attributes`. + * This overload uses `args...` to construct an `event_attributes` to + * associate with the range. For more detail, see `event_attributes`. + * + * Example: + * \code{cpp} + * // Manually begin a range + * nvtx3::range_handle h = nvtx3::start_range_in("msg", nvtx3::rgb{127,255,0}); + * ... + * nvtx3::end_range_in(h); // Ends the range + * \endcode + * + * @tparam D Type containing `name` member used to identify the `domain` + * to which the range belongs. Else, `domain::global` to indicate that the + * global NVTX domain should be used. + * @param args[in] Variadic parameter pack of the arguments for an `event_attributes`. * @return Unique handle to be passed to `end_range` to end the range. */ -template >::value>> -range_handle start_range(First const& first, Args const&... args) noexcept +template +inline range_handle start_range_in(Args const&... args) noexcept { #ifndef NVTX_DISABLE - return start_range(event_attributes{first, args...}); + return start_range_in(event_attributes{args...}); #else - (void)first; - return range_handle{}; + return {}; #endif } /** - * @brief Manually end the range associated with the handle `r`. + * @brief Manually begin an NVTX range in the global domain. + * + * Explicitly begins an NVTX range and returns a unique handle. To end the + * range, pass the handle to `end_range()`. + * + * `nvtx3::start_range(...)` is equivalent to `nvtx3::start_range_in<>(...)` and + * `nvtx3::start_range_in(...)`. + * + * `start_range/end_range` are the most explicit and lowest level APIs + * provided for creating ranges. Use of `nvtx3::unique_range` should be + * preferred unless one is unable to tie the range to the lifetime of an object. + * + * Example: + * \code{.cpp} + * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}}; + * // Manually begin a range + * nvtx3::range_handle h = nvtx3::start_range(attr); + * ... + * nvtx3::end_range(h); // End the range + * \endcode + * + * @param[in] attr `event_attributes` that describes the desired attributes + * of the range. + * @return Unique handle to be passed to `end_range_in` to end the range. + */ +inline range_handle start_range(event_attributes const& attr) noexcept +{ +#ifndef NVTX_DISABLE + return start_range_in(attr); +#else + (void)attr; + return {}; +#endif +} + +/** + * @brief Manually begin an NVTX range in the global domain. + * + * Explicitly begins an NVTX range and returns a unique handle. To end the + * range, pass the handle to `end_range_in()`. + * + * `nvtx3::start_range(...)` is equivalent to `nvtx3::start_range_in<>(...)` and + * `nvtx3::start_range_in(...)`. + * + * `start_range_in/end_range_in` are the most explicit and lowest level APIs + * provided for creating ranges. Use of `nvtx3::unique_range_in` should be + * preferred unless one is unable to tie the range to the lifetime of an object. + * + * This overload uses `args...` to construct an `event_attributes` to + * associate with the range. For more detail, see `event_attributes`. + * + * Example: + * \code{cpp} + * // Manually begin a range + * nvtx3::range_handle h = nvtx3::start_range("msg", nvtx3::rgb{127,255,0}); + * ... + * nvtx3::end_range(h); // Ends the range + * \endcode + * + * @param args[in] Variadic parameter pack of the arguments for an `event_attributes`. + * @return Unique handle to be passed to `end_range` to end the range. + */ +template +inline range_handle start_range(Args const&... args) noexcept +{ +#ifndef NVTX_DISABLE + return start_range_in(args...); +#else + return {}; +#endif +} + +/** + * @brief Manually end the range associated with the handle `r` in domain `D`. + * + * Explicitly ends the NVTX range indicated by the handle `r` returned from a + * prior call to `start_range_in`. The range may end on a different thread + * from where it began. + * + * @tparam D Type containing `name` member used to identify the `domain` to + * which the range belongs. Else, `domain::global` to indicate that the global + * NVTX domain should be used. + * @param r Handle to a range started by a prior call to `start_range_in`. + * + * @warning The domain type specified as template parameter to this function + * must be the same that was specified on the associated `start_range_in` call. + */ +template +inline void end_range_in(range_handle r) noexcept +{ +#ifndef NVTX_DISABLE + nvtxDomainRangeEnd(domain::get(), r.get_value()); +#else + (void)r; +#endif +} + +/** + * @brief Manually end the range associated with the handle `r` in the global + * domain. * * Explicitly ends the NVTX range indicated by the handle `r` returned from a * prior call to `start_range`. The range may end on a different thread from * where it began. * - * This function does not have a Domain tag type template parameter as the - * handle `r` already indicates the domain to which the range belongs. - * * @param r Handle to a range started by a prior call to `start_range`. + * + * @warning The domain type specified as template parameter to this function + * must be the same that was specified on the associated `start_range` call. */ -inline void end_range(range_handle r) +inline void end_range(range_handle r) noexcept { #ifndef NVTX_DISABLE - nvtxRangeEnd(r.get_value()); + end_range_in(r); #else (void)r; #endif @@ -2043,120 +2494,145 @@ inline void end_range(range_handle r) * When constructed, begins a NVTX range in the specified domain. Upon * destruction, ends the NVTX range. * - * Similar to `nvtx3::domain_thread_range`, the only difference being that - * `domain_process_range` can start and end on different threads. + * Similar to `nvtx3::scoped_range_in`, with a few key differences: + * - `unique_range` objects can be destroyed in an order whereas `scoped_range` objects must be + * destroyed in exact reverse creation order + * - `unique_range` can start and end on different threads + * - `unique_range` is moveable + * - `unique_range` objects can be constructed as heap objects * - * Use of `nvtx3::domain_thread_range` should be preferred unless one needs - * the ability to start and end a range on different threads. - * - * `domain_process_range` is moveable, but not copyable. + * There is extra overhead associated with `unique_range` constructs and therefore use of + * `nvtx3::scoped_range_in` should be preferred. * * @tparam D Type containing `name` member used to identify the `domain` - * to which the `domain_process_range` belongs. Else, `domain::global` to + * to which the `unique_range_in` belongs. Else, `domain::global` to * indicate that the global NVTX domain should be used. */ template -class domain_process_range { +class unique_range_in { public: /** - * @brief Construct a new domain process range object + * @brief Construct a new unique_range_in object with the specified event attributes * - * @param attr + * Example: + * \code{cpp} + * nvtx3::event_attributes attr{"msg", nvtx3::rgb{127,255,0}}; + * nvtx3::unique_range_in range{attr}; // Creates a range with message contents + * // "msg" and green color + * \endcode + * + * @param[in] attr `event_attributes` that describes the desired attributes + * of the range. */ - explicit domain_process_range(event_attributes const& attr) noexcept - : handle_{new range_handle{start_range(attr)}} + explicit unique_range_in(event_attributes const& attr) noexcept + : handle_{start_range_in(attr)} { } /** - * @brief Construct a new domain process range object + * @brief Constructs a `unique_range_in` from the constructor arguments + * of an `event_attributes`. * - * @param first - * @param args + * Forwards the arguments `args...` to construct an + * `event_attributes` object. The `event_attributes` object is then + * associated with the `unique_range_in`. + * + * For more detail, see `event_attributes` documentation. + * + * Example: + * \code{.cpp} + * // Creates a range with message "message" and green color + * nvtx3::unique_range_in<> r{"message", nvtx3::rgb{127,255,0}}; + * \endcode + * + * @param[in] args Variadic parameter pack of arguments to construct an `event_attributes` + * associated with this range. */ - template >::value>> - explicit domain_process_range(First const& first, Args const&... args) noexcept - : domain_process_range{event_attributes{first, args...}} + template + explicit unique_range_in(Args const&... args) noexcept + : unique_range_in{event_attributes{args...}} { } /** - * @brief Construct a new domain process range object + * @brief Default constructor creates a `unique_range_in` with no + * message, color, payload, nor category. * */ - constexpr domain_process_range() noexcept : domain_process_range{event_attributes{}} {} + constexpr unique_range_in() noexcept : unique_range_in{event_attributes{}} {} /** - * @brief Destroy the `domain_process_range` ending the range. + * @brief Destroy the `unique_range_in` ending the range. * */ - ~domain_process_range() - { - if (handle_) { end_range(*handle_); } - } + ~unique_range_in() noexcept = default; /** * @brief Move constructor allows taking ownership of the NVTX range from - * another `domain_process_range`. + * another `unique_range_in`. * - * @param other + * @param other The range to take ownership of */ - domain_process_range(domain_process_range&& other) = default; + unique_range_in(unique_range_in&& other) noexcept = default; /** * @brief Move assignment operator allows taking ownership of an NVTX range - * from another `domain_process_range`. + * from another `unique_range_in`. * - * @param other - * @return domain_process_range& + * @param other The range to take ownership of */ - domain_process_range& operator=(domain_process_range&& other) = default; + unique_range_in& operator=(unique_range_in&& other) noexcept = default; /// Copy construction is not allowed to prevent multiple objects from owning /// the same range handle - domain_process_range(domain_process_range const&) = delete; + unique_range_in(unique_range_in const&) = delete; /// Copy assignment is not allowed to prevent multiple objects from owning the /// same range handle - domain_process_range& operator=(domain_process_range const&) = delete; + unique_range_in& operator=(unique_range_in const&) = delete; private: - std::unique_ptr handle_; ///< Range handle used to correlate - ///< the start/end of the range + + struct end_range_handle { + using pointer = range_handle; /// Override the pointer type of the unique_ptr + void operator()(range_handle h) const noexcept { end_range_in(h); } + }; + + /// Range handle used to correlate the start/end of the range + std::unique_ptr handle_; }; /** - * @brief Alias for a `domain_process_range` in the global NVTX domain. + * @brief Alias for a `unique_range_in` in the global NVTX domain. * */ -using process_range = domain_process_range<>; +using unique_range = unique_range_in; /** - * @brief Annotates an instantaneous point in time with the attributes specified - * by `attr`. + * @brief Annotates an instantaneous point in time with a "marker", using the + * attributes specified by `attr`. * - * Unlike a "range", a mark is an instantaneous event in an application, e.g., - * locking/unlocking a mutex. + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: * * \code{.cpp} - * std::mutex global_lock; - * void lock_mutex(){ - * global_lock.lock(); - * nvtx3::mark("lock_mutex"); + * bool success = do_operation(...); + * if (!success) { + * nvtx3::event_attributes attr{"operation failed!", nvtx3::rgb{255,0,0}}; + * nvtx3::mark_in(attr); * } * \endcode * + * Note that nvtx3::mark_in is a function, not a class like scoped_range_in. + * * @tparam D Type containing `name` member used to identify the `domain` - * to which the `domain_process_range` belongs. Else, `domain::global` to + * to which the `unique_range_in` belongs. Else, `domain::global` to * indicate that the global NVTX domain should be used. * @param[in] attr `event_attributes` that describes the desired attributes * of the mark. */ template -inline void mark(event_attributes const& attr) noexcept +inline void mark_in(event_attributes const& attr) noexcept { #ifndef NVTX_DISABLE nvtxDomainMarkEx(domain::get(), attr.get()); @@ -2165,10 +2641,105 @@ inline void mark(event_attributes const& attr) noexcept #endif } +/** + * @brief Annotates an instantaneous point in time with a "marker", using the + * arguments to construct an `event_attributes`. + * + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: + * + * \code{.cpp} + * bool success = do_operation(...); + * if (!success) { + * nvtx3::mark_in("operation failed!", nvtx3::rgb{255,0,0}); + * } + * \endcode + * + * Note that nvtx3::mark_in is a function, not a class like scoped_range_in. + * + * Forwards the arguments `args...` to construct an `event_attributes` object. + * The attributes are then associated with the marker. For more detail, see + * the `event_attributes` documentation. + * + * @tparam D Type containing `name` member used to identify the `domain` + * to which the `unique_range_in` belongs. Else `domain::global` to + * indicate that the global NVTX domain should be used. + * @param[in] args Variadic parameter pack of arguments to construct an `event_attributes` + * associated with this range. + * + */ +template +inline void mark_in(Args const&... args) noexcept +{ +#ifndef NVTX_DISABLE + mark_in(event_attributes{args...}); +#endif +} + +/** + * @brief Annotates an instantaneous point in time with a "marker", using the + * attributes specified by `attr`, in the global domain. + * + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: + * + * \code{.cpp} + * bool success = do_operation(...); + * if (!success) { + * nvtx3::event_attributes attr{"operation failed!", nvtx3::rgb{255,0,0}}; + * nvtx3::mark(attr); + * } + * \endcode + * + * Note that nvtx3::mark is a function, not a class like scoped_range. + * + * @param[in] attr `event_attributes` that describes the desired attributes + * of the mark. + */ +inline void mark(event_attributes const& attr) noexcept +{ +#ifndef NVTX_DISABLE + mark_in(attr); +#endif +} + +/** + * @brief Annotates an instantaneous point in time with a "marker", using the + * arguments to construct an `event_attributes`, in the global domain. + * + * Unlike a "range" which has a beginning and an end, a marker is a single event + * in an application, such as detecting a problem: + * + * \code{.cpp} + * bool success = do_operation(...); + * if (!success) { + * nvtx3::mark("operation failed!", nvtx3::rgb{255,0,0}); + * } + * \endcode + * + * Note that nvtx3::mark is a function, not a class like scoped_range. + * + * Forwards the arguments `args...` to construct an `event_attributes` object. + * The attributes are then associated with the marker. For more detail, see + * the `event_attributes` documentation. + * + * @param[in] args Variadic parameter pack of arguments to construct an + * `event_attributes` associated with this range. + * + */ +template +inline void mark(Args const&... args) noexcept +{ +#ifndef NVTX_DISABLE + mark_in(args...); +#endif +} + } // namespace NVTX3_VERSION_NAMESPACE } // namespace nvtx3 +#ifndef NVTX_DISABLE /** * @brief Convenience macro for generating a range in the specified `domain` * from the lifetime of a function @@ -2177,34 +2748,58 @@ inline void mark(event_attributes const& attr) noexcept * the entry point of a function to its exit. It is intended to be the first * line of the function. * - * Constructs a static `registered_string` using the name of the immediately + * Constructs a static `registered_string_in` using the name of the immediately * enclosing function returned by `__func__` and constructs a - * `nvtx3::thread_range` using the registered function name as the range's + * `nvtx3::scoped_range` using the registered function name as the range's * message. * * Example: - * ``` + * \code{.cpp} * struct my_domain{static constexpr char const* name{"my_domain"};}; * - * void foo(...){ + * void foo(...) { * NVTX3_FUNC_RANGE_IN(my_domain); // Range begins on entry to foo() * // do stuff * ... * } // Range ends on return from foo() - * ``` + * \endcode * * @param[in] D Type containing `name` member used to identify the - * `domain` to which the `registered_string` belongs. Else, + * `domain` to which the `registered_string_in` belongs. Else, * `domain::global` to indicate that the global NVTX domain should be used. */ -#ifndef NVTX_DISABLE #define NVTX3_V1_FUNC_RANGE_IN(D) \ - static ::nvtx3::v1::registered_string const nvtx3_func_name__{__func__}; \ + static ::nvtx3::v1::registered_string_in const nvtx3_func_name__{__func__}; \ static ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__}; \ - ::nvtx3::v1::domain_thread_range const nvtx3_range__{nvtx3_func_attr__}; + ::nvtx3::v1::scoped_range_in const nvtx3_range__{nvtx3_func_attr__}; + +/** + * @brief Convenience macro for generating a range in the specified `domain` + * from the lifetime of a function if the given boolean expression evaluates + * to true. + * + * Similar to `NVTX3_V1_FUNC_RANGE_IN(D)`, the only difference being that + * `NVTX3_V1_FUNC_RANGE_IF_IN(D, C)` only generates a range if the given boolean + * expression evaluates to true. + * + * @param[in] D Type containing `name` member used to identify the + * `domain` to which the `registered_string_in` belongs. Else, + * `domain::global` to indicate that the global NVTX domain should be used. + * + * @param[in] C Boolean expression used to determine if a range should be + * generated. + */ +#define NVTX3_V1_FUNC_RANGE_IF_IN(D, C) \ + ::nvtx3::v1::detail::optional_scoped_range_in optional_nvtx3_range__; \ + if (C) { \ + static ::nvtx3::v1::registered_string_in const nvtx3_func_name__{__func__}; \ + static ::nvtx3::v1::event_attributes const nvtx3_func_attr__{nvtx3_func_name__}; \ + optional_nvtx3_range__.begin(nvtx3_func_attr__); \ + } #else #define NVTX3_V1_FUNC_RANGE_IN(D) -#endif +#define NVTX3_V1_FUNC_RANGE_IF_IN(D, C) +#endif // NVTX_DISABLE /** * @brief Convenience macro for generating a range in the global domain from the @@ -2214,28 +2809,43 @@ inline void mark(event_attributes const& attr) noexcept * the entry point of a function to its exit. It is intended to be the first * line of the function. * - * Constructs a static `registered_string` using the name of the immediately + * Constructs a static `registered_string_in` using the name of the immediately * enclosing function returned by `__func__` and constructs a - * `nvtx3::thread_range` using the registered function name as the range's + * `nvtx3::scoped_range` using the registered function name as the range's * message. * * Example: - * ``` - * void foo(...){ + * \code{.cpp} + * void foo(...) { * NVTX3_FUNC_RANGE(); // Range begins on entry to foo() * // do stuff * ... * } // Range ends on return from foo() - * ``` + * \endcode */ #define NVTX3_V1_FUNC_RANGE() NVTX3_V1_FUNC_RANGE_IN(::nvtx3::v1::domain::global) +/** + * @brief Convenience macro for generating a range in the global domain from the + * lifetime of a function if the given boolean expression evaluates to true. + * + * Similar to `NVTX3_V1_FUNC_RANGE()`, the only difference being that + * `NVTX3_V1_FUNC_RANGE_IF(C)` only generates a range if the given boolean + * expression evaluates to true. + * + * @param[in] C Boolean expression used to determine if a range should be + * generated. + */ +#define NVTX3_V1_FUNC_RANGE_IF(C) NVTX3_V1_FUNC_RANGE_IF_IN(::nvtx3::v1::domain::global, C) + /* When inlining this version, versioned macros must have unversioned aliases. * For each NVTX3_Vx_ #define, make an NVTX3_ alias of it here.*/ #if defined(NVTX3_INLINE_THIS_VERSION) /* clang format off */ -#define NVTX3_FUNC_RANGE_IN NVTX3_V1_FUNC_RANGE_IN -#define NVTX3_FUNC_RANGE NVTX3_V1_FUNC_RANGE +#define NVTX3_FUNC_RANGE NVTX3_V1_FUNC_RANGE +#define NVTX3_FUNC_RANGE_IF NVTX3_V1_FUNC_RANGE_IF +#define NVTX3_FUNC_RANGE_IN NVTX3_V1_FUNC_RANGE_IN +#define NVTX3_FUNC_RANGE_IF_IN NVTX3_V1_FUNC_RANGE_IF_IN /* clang format on */ #endif @@ -2278,8 +2888,18 @@ inline void mark(event_attributes const& attr) noexcept #undef NVTX3_NAMESPACE_FOR #undef NVTX3_VERSION_NAMESPACE #undef NVTX3_INLINE_IF_REQUESTED -#undef NVTX3_RELAXED_CONSTEXPR +#undef NVTX3_CONSTEXPR_IF_CPP14 #if defined(NVTX3_INLINE_THIS_VERSION) #undef NVTX3_INLINE_THIS_VERSION #endif + +#if defined(NVTX3_USE_CHECKED_OVERLOADS_FOR_GET_DEFINED_HERE) +#undef NVTX3_USE_CHECKED_OVERLOADS_FOR_GET_DEFINED_HERE +#undef NVTX3_USE_CHECKED_OVERLOADS_FOR_GET +#endif + +#if defined(NVTX3_STATIC_ASSERT_DEFINED_HERE) +#undef NVTX3_STATIC_ASSERT_DEFINED_HERE +#undef NVTX3_STATIC_ASSERT +#endif diff --git a/src/include/nvtx3/nvtxDetail/nvtxImpl.h b/src/include/nvtx3/nvtxDetail/nvtxImpl.h index be27f4394a..590ce90243 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImpl.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImpl.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCore.h b/src/include/nvtx3/nvtxDetail/nvtxImplCore.h index 9f014ca7f5..7a48aa8505 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplCore.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplCore.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h index d4c0cdf7eb..156f15a6e2 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplCudaRt_v3.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h index 4b5d6c7b42..5a379b10fc 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplCuda_v3.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h index 4a026f024a..bd8d404a5f 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplOpenCL_v3.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h b/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h index 90616da322..686686cf94 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h +++ b/src/include/nvtx3/nvtxDetail/nvtxImplSync_v3.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxInit.h b/src/include/nvtx3/nvtxDetail/nvtxInit.h index 44dcc0ff7a..43cad70105 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxInit.h +++ b/src/include/nvtx3/nvtxDetail/nvtxInit.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h b/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h index 261681b792..a52e278faa 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h +++ b/src/include/nvtx3/nvtxDetail/nvtxInitDecls.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h b/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h index ded156c7a1..a670d96e68 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h +++ b/src/include/nvtx3/nvtxDetail/nvtxInitDefs.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h b/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h index 908ce88b10..57661c7541 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h +++ b/src/include/nvtx3/nvtxDetail/nvtxLinkOnce.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxDetail/nvtxTypes.h b/src/include/nvtx3/nvtxDetail/nvtxTypes.h index 53c6c00727..f646b54460 100644 --- a/src/include/nvtx3/nvtxDetail/nvtxTypes.h +++ b/src/include/nvtx3/nvtxDetail/nvtxTypes.h @@ -1,5 +1,5 @@ /* -* Copyright 2009-2020 NVIDIA Corporation. All rights reserved. +* Copyright 2009-2022 NVIDIA Corporation. All rights reserved. * * Licensed under the Apache License v2.0 with LLVM Exceptions. * See https://llvm.org/LICENSE.txt for license information. diff --git a/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h b/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h index d589f639f9..4663fda824 100644 --- a/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h +++ b/src/include/nvtx3/nvtxExtDetail/nvtxExtImplPayload_v1.h @@ -35,10 +35,11 @@ NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots)[NVTX3EXT_CBID_PAYLOAD_FN_NUM NVTX_LINKONCE_DEFINE_FUNCTION void NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadInitOnce)() { + intptr_t* fnSlots = NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots) + 1; nvtxExtModuleSegment_t segment = { 0, // unused (only one segment) NVTX3EXT_CBID_PAYLOAD_FN_NUM, - NVTX_EXT_PAYLOAD_VERSIONED_ID(nvtxExtPayloadSlots) + 1 + fnSlots }; nvtxExtModuleInfo_t module = { diff --git a/src/include/proxy.h b/src/include/proxy.h index fc99135a3a..83b8937861 100644 --- a/src/include/proxy.h +++ b/src/include/proxy.h @@ -12,6 +12,7 @@ #include "devcomm.h" #include "info.h" #include "socket.h" +#include "ipcsocket.h" #include #include "shm.h" @@ -171,6 +172,31 @@ struct ncclProxyProgressState { int nextOps; }; +// Expected proxy response fifo +struct ncclExpectedProxyResponse { + void* opId; + int respSize; + bool done; + void* respBuff; + struct ncclExpectedProxyResponse* next; +}; + +struct ncclProxyAsyncOp { + int type; + struct ncclProxyConnection* connection; + int reqSize, respSize; + char *reqBuff, *respBuff; + void* opId; + ncclProxyAsyncOp* next; +}; + +struct ncclProxyLocalPeer { + struct ncclSocket sock; + int localRank; + ncclProxyAsyncOp* asyncOps; + int asyncOpCounter; +}; + struct ncclProxyState { // Service thread pthread_t thread; @@ -186,6 +212,9 @@ struct ncclProxyState { // Progress thread struct ncclProxyProgressState progressState; + + // Queue of expected responses from the proxy + struct ncclExpectedProxyResponse* expectedResponses; }; enum proxyConnectState { @@ -230,10 +259,19 @@ enum ncclProxyMsgType { ncclProxyMsgStart = 5, ncclProxyMsgClose = 6, ncclProxyMsgAbort = 7, - ncclProxyMsgStop = 8 + ncclProxyMsgStop = 8, + ncclProxyMsgConvertFd = 9 // cuMem API support }; -ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize); +// This function is called by a client of the proxy that needs to invoke any of the non-progress proxyOp types +// Call this function on the client, supplying a locally unique opId. Then, poll on the return value of +// ncclPollProxyResponse(), supplying the same opId to confirm the operation has completed +ncclResult_t ncclProxyCallAsync(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, int respSize, void* opId); + +// This function will internally call ncclProxyCallAsync() and spin until ncclPollProxyResponse() confirms the result is received +ncclResult_t ncclProxyCallBlocking(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize); +ncclResult_t ncclPollProxyResponse(struct ncclProxyConnector* proxyConn, void* respBuff, void* opId); + ncclResult_t ncclProxyDestroy(struct ncclComm* comm); ncclResult_t ncclProxyShmUnlink(struct ncclComm* comm); diff --git a/src/include/rocmwrap.h b/src/include/rocmwrap.h index 75feeb4504..aeeed0ab40 100644 --- a/src/include/rocmwrap.h +++ b/src/include/rocmwrap.h @@ -70,4 +70,6 @@ DECLARE_ROCM_PFN_EXTERN(hsa_status_string); ncclResult_t rocmLibraryInit(void); +extern bool ncclCudaLaunchBlocking; // initialized by ncclCudaLibraryInit() + #endif diff --git a/src/include/socket.h b/src/include/socket.h index f695267b29..9f271798bc 100644 --- a/src/include/socket.h +++ b/src/include/socket.h @@ -92,6 +92,6 @@ ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset); ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size); -ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed); +ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed, bool blocking); ncclResult_t ncclSocketClose(struct ncclSocket* sock); #endif diff --git a/src/include/transport.h b/src/include/transport.h index e957423729..f43464dec9 100644 --- a/src/include/transport.h +++ b/src/include/transport.h @@ -65,7 +65,7 @@ struct ncclTransportComm { }; struct ncclTransport { - const char name[4]; + const char name[8]; ncclResult_t (*canConnect)(int*, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo*, struct ncclPeerInfo*); struct ncclTransportComm send; struct ncclTransportComm recv; @@ -74,6 +74,9 @@ struct ncclTransport { ncclResult_t ncclTransportP2pConnect(struct ncclComm* comm, int channelId, int nrecv, int* peerRecv, int nsend, int* peerSend, int connIndex); ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, int connIndex, int* highestTransportType=NULL); +ncclResult_t ncclNvlsSetup(struct ncclComm* comm); +ncclResult_t ncclNvlsFree(struct ncclComm* comm); + enum { collNetRecv=0, collNetSend=1 }; int ncclTransportCollNetSetup(struct ncclComm* comm, struct ncclTopoGraph* collNetGraph, struct ncclChannel* channel, int masterRank, int masterPeer, int collNetGraphChannelId, int type); ncclResult_t ncclTransportCollNetCheck(struct ncclComm* comm, int collNetSetupFail); diff --git a/src/init.cc b/src/init.cc index 1f4b47ac61..41f3db9f64 100644 --- a/src/init.cc +++ b/src/init.cc @@ -53,7 +53,7 @@ #endif const char* ncclFuncStr[NCCL_NUM_FUNCTIONS+2] = { "Broadcast", "Reduce", "AllGather", "ReduceScatter", "AllReduce", "SendRecv", "AllToAllPivot" }; -const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain" }; +const char* ncclAlgoStr[NCCL_NUM_ALGORITHMS] = { "Tree", "Ring", "CollNetDirect", "CollNetChain", "NVLS" }; const char* ncclProtoStr[NCCL_NUM_PROTOCOLS] = { "LL", "LL128", "Simple" }; const char* ncclDevRedOpStr[ncclNumDevRedOps] = { "Sum", "Prod", "Max", "Min", "PreMulSum", "SumPostDiv" }; const char *ncclTypeStr[ncclNumTypes] = {"_i8", "_u8", "_i32", "_u32", "_i64", "_u64", "_f16", "_f32", "_f64", "_b16"}; @@ -61,7 +61,7 @@ const char *ncclTypeStr[ncclNumTypes] = {"_i8", "_u8", "_i32", "_u32", "_i64", " NCCL_PARAM(GroupCudaStream, "GROUP_CUDA_STREAM", NCCL_GROUP_CUDA_STREAM); NCCL_PARAM(CheckPointers, "CHECK_POINTERS", 0); -NCCL_PARAM(CommBlocking, "COMM_BLOCKING", 0); +NCCL_PARAM(CommBlocking, "COMM_BLOCKING", NCCL_CONFIG_UNDEF_INT); struct allocationTracker allocTracker[MAX_ALLOC_TRACK_NGPU] = {}; @@ -89,12 +89,8 @@ ncclResult_t initGdrCopy() { return ncclSuccess; } - -NCCL_PARAM(L1SharedMemoryCarveout, "L1_SHARED_MEMORY_CARVEOUT", 0); - pthread_mutex_t initLock = PTHREAD_MUTEX_INITIALIZER; static bool initialized = false; -static size_t maxLocalSizeBytes = 0; static ncclResult_t ncclInit() { if (__atomic_load_n(&initialized, __ATOMIC_ACQUIRE)) return ncclSuccess; @@ -102,9 +98,6 @@ static ncclResult_t ncclInit() { if (!initialized) { initEnv(); initGdrCopy(); - maxLocalSizeBytes = ncclKernMaxLocalSize(); - int carveout = ncclParamL1SharedMemoryCarveout(); - if (carveout) ncclKernSetSharedMemoryCarveout(carveout); // Always initialize bootstrap network NCCLCHECK(bootstrapNetInit()); NCCLCHECK(ncclNetPluginInit()); @@ -380,6 +373,8 @@ static ncclResult_t commFree(ncclComm_t comm) { NCCLCHECK(ncclStrongStreamDestruct(&comm->deviceStream)); } + if (comm->nvlsSupport) NCCLCHECK(ncclNvlsFree(comm)); + struct ncclDestructor* dtor = comm->destructorHead; while (dtor != nullptr) { NCCLCHECK(dtor->fn(dtor)); @@ -391,6 +386,7 @@ static ncclResult_t commFree(ncclComm_t comm) { ncclMemoryStackDestruct(&comm->memPermanent); ncclCudaHostFree((void *)comm->abortFlag); + free(comm->netName); commPoison(comm); // poison comm before free to avoid comm reuse. free(comm); @@ -418,8 +414,8 @@ static ncclResult_t dmaBufSupported(struct ncclComm* comm) { int flag = 0; CUdevice dev; int cudaDriverVersion; - CUCHECK(cuDriverGetVersion(&cudaDriverVersion)); - if (cudaDriverVersion < 11070) return ncclInternalError; + CUDACHECK(cudaDriverGetVersion(&cudaDriverVersion)); + if (CUPFN(cuDeviceGet) == NULL || cudaDriverVersion < 11070) return ncclInternalError; CUCHECK(cuDeviceGet(&dev, comm->cudaDev)); // Query device to see if DMA-BUF support is available (void) CUPFN(cuDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED, dev)); @@ -442,7 +438,7 @@ ncclResult_t ncclCommEnsureReady(ncclComm_t comm) { NCCLCHECK(ncclCommGetAsyncError(comm, &ret)); if (ret != ncclSuccess) { /* if ret is not ncclInProgress, we just keep it. */ - WARN("Attempt to use communicator before the previous operation returned ncclSuccess\n"); + WARN("Attempt to use communicator before the previous operation returned ncclSuccess"); if (ret == ncclInProgress) ret = ncclInvalidArgument; goto exit; } @@ -596,6 +592,7 @@ static ncclResult_t devCommSetup(ncclComm_t comm) { tmpCommAndChans.channels[c].collnetChain = comm->channels[c].collnetChain; tmpCommAndChans.channels[c].collnetDirect = comm->channels[c].collnetDirect; tmpCommAndChans.channels[c].binTree = comm->channels[c].binTree; + tmpCommAndChans.channels[c].nvls = comm->channels[c].nvls; tmpCommAndChans.channels[c].workFifoDone = &comm->workFifoDone[c]; if (comm->channels[c].ring.userRanks != nullptr) { @@ -759,8 +756,8 @@ static ncclResult_t collNetTrySetup(ncclComm_t comm, struct ncclTopoGraph* collN struct ncclChannel* channel = comm->channels + c; for (int h = 0; h < nHeads; h++) { const int head = heads[h]; - collNetSetupFail = ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetRecv); - if (!collNetSetupFail) collNetSetupFail = ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetSend); + collNetSetupFail |= ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetRecv); + if (!collNetSetupFail) collNetSetupFail |= ncclTransportCollNetSetup(comm, collNetGraph, channel, head, head, h, collNetSend); } // Verify CollNet setup across ranks after trying the first channel if (c == 0) { @@ -1218,39 +1215,23 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm NCCLCHECKGOTO(ncclCalloc(&rings, nranks*MAXCHANNELS), ret, fail); NCCLCHECKGOTO(ncclTopoPostset(comm, nodesFirstRank, nodesTreePatterns, allTopoRanks, rings, &collNetGraph, nc), ret, fail); - if (comm->topo->pivotA2ANumBiRings == 3) { - NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph)); - if (comm->virtualId == -1) { - NCCLCHECK(ncclBinaryTreeHayabusaPostset(comm, &treeGraph)); - } else { - NCCLCHECK(ncclBinaryTreePostset(comm, &treeGraph)); - } - } + if (comm->topo->pivotA2ANumBiRings == 3) NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph)); // AllGather3 - end TRACE(NCCL_INIT, "rank %d nranks %d - BUILT %d TREES/RINGS", rank, nranks, comm->nChannels); - char line[1024], binline[1024]; + char line[1024]; line[0]='\0'; - binline[0]='\0'; for (int c=0; cnChannels; c++) { struct ncclTree* tree = &comm->channels[c].tree; - struct ncclTree* binTree = &comm->channels[c].binTree; snprintf(line+strlen(line), 1023-strlen(line), " [%d] %d/%d/%d->%d->%d", c, tree->down[0], tree->down[1], tree->down[2], rank, tree->up); - if (comm->topo->pivotA2ANumBiRings == 3) - snprintf(binline+strlen(binline), 1023-strlen(binline), " [%d] %d/%d/%d->%d->%d", - c, binTree->down[0], binTree->down[1], binTree->down[2], rank, binTree->up); INFO(NCCL_GRAPH, "Ring %d : %d -> %d -> %d comm %p nRanks %02d busId %lx", c, comm->channels[c].ring.prev, comm->rank, comm->channels[c].ring.next, comm, comm->nRanks, comm->busId); } line[1023] = '\0'; INFO(NCCL_INIT, "Trees%s comm %p nRanks %02d busId %lx", line, comm, comm->nRanks, comm->busId); - if (comm->topo->pivotA2ANumBiRings == 3) { - binline[1023] = '\0'; - INFO(NCCL_INIT, "BinTrees%s comm %p nRanks %02d busId %lx", binline, comm, comm->nRanks, comm->busId); - } NCCLCHECKGOTO(computeBuffSizes(comm), ret, fail); @@ -1280,11 +1261,6 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm if (comm->nRanks == 1) continue; NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_TREE_ARITY, channel->tree.down, 1, &channel->tree.up, 0), ret, fail); NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->tree.up, NCCL_MAX_TREE_ARITY, channel->tree.down, 0), ret, fail); - // RCCL: need to connect binTree as well - if (comm->topo->pivotA2ANumBiRings == 3) { - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, NCCL_MAX_TREE_ARITY, channel->binTree.down, 1, &channel->binTree.up, 0), ret, fail); - NCCLCHECKGOTO(ncclTransportP2pConnect(comm, c, 1, &channel->binTree.up, NCCL_MAX_TREE_ARITY, channel->binTree.down, 0), ret, fail); - } } NCCLCHECKGOTO(ncclTransportP2pSetup(comm, &treeGraph, 0), ret, fail); INFO(NCCL_INIT, "Connected all trees"); @@ -1292,6 +1268,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm // Check if we can setup CollNet if (comm->collNetSupport > 0) collNetTrySetup(comm, &collNetGraph); + //NCCLCHECKGOTO(ncclNvlsSetup(comm), ret, fail); + TRACE(NCCL_INIT, "rank %d nranks %d - CONNECTED %d RINGS AND TREES", rank, nranks, comm->nChannels); // Compute time models for algorithm and protocol combinations @@ -1299,7 +1277,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm int myCompCap = comm->peerInfo[rank].cudaCompCap; int minCompCap = myCompCap, maxCompCap = myCompCap; for (int i = 0; i < nranks; i++) { - minCompCap = std::min(comm->peerInfo[i].cudaCompCap, minCompCap); + comm->minCompCap = minCompCap = std::min(comm->peerInfo[i].cudaCompCap, minCompCap); maxCompCap = std::max(comm->peerInfo[i].cudaCompCap, maxCompCap); } NCCLCHECKGOTO(ncclTopoTuneModel(comm, minCompCap, maxCompCap, &treeGraph, &ringGraph, &collNetGraph), ret, fail); @@ -1308,6 +1286,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm // Compute nChannels per peer for p2p NCCLCHECKGOTO(ncclTopoComputeP2pChannels(comm), ret, fail); + INFO(NCCL_INIT, "%d coll channels, %d nvls channels, %d p2p channels, %d p2p channels per peer", comm->nChannels, comm->nvlsChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer); + do { // Setup p2p structures in comm->tasks struct ncclTasks* tasks = &comm->tasks; int nRanks = comm->nRanks; @@ -1374,12 +1354,13 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm } } } + NCCLCHECKGOTO(ncclTransportP2pSetup(comm, NULL, 1), ret, fail); } // Connect to local net proxy NCCLCHECKGOTO(ncclProxyConnect(comm, TRANSPORT_NET, 1, comm->rank, &proxyConn), ret, fail); - NCCLCHECKGOTO(ncclProxyCall(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); + NCCLCHECKGOTO(ncclProxyCallBlocking(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); // Then to remote ones when using PXN if (ncclPxnDisable(comm) == 0) { @@ -1387,7 +1368,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm NCCLCHECKGOTO(ncclTopoGetPxnRanks(comm, &pxnPeers, &nranks), ret, fail); for (int r=0; rp2pnChannels, sizeof(int), NULL, 0), ret, fail); + NCCLCHECKGOTO(ncclProxyCallBlocking(&proxyConn, ncclProxyMsgSharedInit, &comm->p2pnChannels, sizeof(int), NULL, 0), ret, fail); } } @@ -1441,6 +1422,11 @@ RCCL_PARAM(StackSizeOverride, "STACK_SIZE_OVERRIDE", 512); NCCL_PARAM(SetStackSize, "SET_STACK_SIZE", 0); RCCL_PARAM(StackSizeOverride, "STACK_SIZE_OVERRIDE", 0); #endif +NCCL_PARAM(CGAClusterSize, "CGA_CLUSTER_SIZE", NCCL_CONFIG_UNDEF_INT); +// Match config max/minCTAs +NCCL_PARAM(MaxCTAs, "MAX_CTAS", NCCL_CONFIG_UNDEF_INT); +NCCL_PARAM(MinCTAs, "MIN_CTAS", NCCL_CONFIG_UNDEF_INT); +#define NCCL_MAX_CGA_CLUSTER_SIZE 8 struct ncclCommInitRankAsyncJob { struct ncclAsyncJob base; @@ -1465,10 +1451,17 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { int myrank = job->myrank; int cudaDev = job->cudaDev; int virtualId = job->virtualId; + int archMajor, archMinor; + size_t maxLocalSizeBytes = 0; ncclResult_t res = ncclSuccess; int64_t stackSize = rcclParamStackSizeOverride() ? rcclParamStackSizeOverride() : maxLocalSizeBytes; CUDACHECKGOTO(cudaSetDevice(cudaDev), res, fail); + CUDACHECK(cudaDeviceGetAttribute(&archMajor, cudaDevAttrComputeCapabilityMajor, cudaDev)); + CUDACHECK(cudaDeviceGetAttribute(&archMinor, cudaDevAttrComputeCapabilityMinor, cudaDev)); + comm->cudaArch = 100*archMajor + 10*archMinor; + + NCCLCHECK(ncclInitKernelsForDevice(comm->cudaArch, &maxLocalSizeBytes)); // Set the maximum kernel stack size of all kernels to avoid // a CUDA memory reconfig on load (c.f. NVSHMEM issue) #ifdef USE_INDIRECT_FUNCTION_CALL @@ -1487,7 +1480,7 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { TRACE_CALL("ncclCommInitRank(%p, %d, 0x%llx, %d, %d)", *newcomm, nranks, (unsigned long long)hashUniqueId(commId), myrank, (*newcomm)->cudaDev); - INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx localSize %ld used %ld bytes - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId, ncclKernLocalSize(ncclGetKernelIndex(*newcomm)), allocTracker[(*newcomm)->cudaDev].totalAllocSize); + INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx localSize %zi used %ld bytes - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId, maxLocalSizeBytes, allocTracker[(*newcomm)->cudaDev].totalAllocSize); exit: return res; fail: @@ -1495,18 +1488,143 @@ fail: goto exit; } -static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { - ncclResult_t ret = ncclSuccess; - - /* first set configuration */ - if (config) { - comm->blocking = config->blocking; - } else { - /* default setting of communicator */ - comm->blocking = 1; +#define NCCL_CONFIG_DEFAULT(config, field, undef, defvalue, fieldStr, format) \ + if (config->field == undef) { \ + config->field = defvalue; \ + } else { \ + INFO(NCCL_ENV, "Comm config " fieldStr " set to " format, config->field); \ } +static ncclResult_t parseCommConfig(ncclComm_t comm, ncclConfig_t *config) { + ncclResult_t ret = ncclSuccess; + /* config must not be NULL in this function */ + int blockingEnv; + int cgaClusterSizeEnv; + int minCTAsEnv; + int maxCTAsEnv; + const char *envNetName, *tmpNetName; + ncclConfig_t defaultConfig = NCCL_CONFIG_INITIALIZER; + ncclConfig_t internalConfig = NCCL_CONFIG_INITIALIZER; + ncclConfig_t *internalConfigPtr; + size_t realSize; + + internalConfigPtr = &internalConfig; + if (config) { + memcpy((void*)&realSize, (void*)config, sizeof(size_t)); + realSize = realSize > sizeof(ncclConfig_t) ? sizeof(ncclConfig_t) : realSize; + memcpy((void*)internalConfigPtr, (void*)config, realSize); + if (internalConfigPtr->magic != 0xcafebeef) { + WARN("ncclConfig_t argument not initialized via NCCL_CONFIG_INITIALIZER"); + ret = ncclInvalidArgument; + goto fail; + } + + /* check version. */ + if (internalConfigPtr->version < NCCL_VERSION(2, 14, 0)) { + internalConfigPtr->blocking = defaultConfig.blocking; + } + + if (internalConfigPtr->version < NCCL_VERSION(2, 17, 0)) { + internalConfigPtr->cgaClusterSize = defaultConfig.cgaClusterSize; + internalConfigPtr->minCTAs = defaultConfig.minCTAs; + internalConfigPtr->maxCTAs = defaultConfig.maxCTAs; + internalConfigPtr->netName = defaultConfig.netName; + } + } + + /* check input config attributes, -1 means user-undefined and we should use default value from NCCL. */ + if (internalConfigPtr->blocking != NCCL_CONFIG_UNDEF_INT && internalConfigPtr->blocking != 0 && internalConfigPtr->blocking != 1) { + WARN("Invalid config blocking attribute value %d", internalConfigPtr->blocking); + ret = ncclInvalidArgument; + goto fail; + } + + if (internalConfigPtr->cgaClusterSize != NCCL_CONFIG_UNDEF_INT && internalConfigPtr->cgaClusterSize < 0) { + WARN("Invalid config cgaClusterSize attribute value %d", internalConfigPtr->cgaClusterSize); + ret = ncclInvalidArgument; + goto fail; + } + + if ((internalConfigPtr->minCTAs != NCCL_CONFIG_UNDEF_INT && + internalConfigPtr->minCTAs <= 0) || + (internalConfigPtr->maxCTAs != NCCL_CONFIG_UNDEF_INT && + internalConfigPtr->maxCTAs <= 0) || + (internalConfigPtr->minCTAs > internalConfigPtr->maxCTAs)) { + WARN("Invalid config min/max channels attribute value %d/%d", internalConfigPtr->minCTAs, internalConfigPtr->maxCTAs); + ret = ncclInvalidArgument; + goto fail; + } + + /* default config value can be tuned on different platform. */ + NCCL_CONFIG_DEFAULT(internalConfigPtr, blocking, NCCL_CONFIG_UNDEF_INT, 1, "Blocking", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, cgaClusterSize, NCCL_CONFIG_UNDEF_INT, 4, "CGA cluster size", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, minCTAs, NCCL_CONFIG_UNDEF_INT, 1, "Min CTAs", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, maxCTAs, NCCL_CONFIG_UNDEF_INT, MAXCHANNELS, "Max CTAs", "%d"); + NCCL_CONFIG_DEFAULT(internalConfigPtr, netName, NCCL_CONFIG_UNDEF_PTR, NULL, "Net name", "%s"); + + tmpNetName = internalConfigPtr->netName; + + /* assign config to communicator */ + comm->blocking = internalConfigPtr->blocking; + comm->cgaClusterSize = internalConfigPtr->cgaClusterSize; + comm->minCTAs = internalConfigPtr->minCTAs; + comm->maxCTAs = internalConfigPtr->maxCTAs; + + /* override configuration from env variable. */ + blockingEnv = ncclParamCommBlocking(); + if (blockingEnv == 0 || blockingEnv == 1) + comm->blocking = blockingEnv; + + cgaClusterSizeEnv = ncclParamCGAClusterSize(); + if (0 <= cgaClusterSizeEnv && cgaClusterSizeEnv <= NCCL_MAX_CGA_CLUSTER_SIZE) { + comm->cgaClusterSize = cgaClusterSizeEnv; + } else if (cgaClusterSizeEnv > NCCL_MAX_CGA_CLUSTER_SIZE) { + WARN("NCCL_CGA_CLUSTER_SIZE value %d is too big. Limiting value to %d.", cgaClusterSizeEnv, NCCL_MAX_CGA_CLUSTER_SIZE); + comm->cgaClusterSize = NCCL_MAX_CGA_CLUSTER_SIZE; + } + + minCTAsEnv = ncclParamMinCTAs(); + if (minCTAsEnv != NCCL_CONFIG_UNDEF_INT) { + comm->minCTAs = minCTAsEnv; + } + + maxCTAsEnv = ncclParamMaxCTAs(); + if (maxCTAsEnv != NCCL_CONFIG_UNDEF_INT) { + comm->maxCTAs = maxCTAsEnv; + } + + /* cap channels if needed */ + if (comm->minCTAs > MAXCHANNELS) { + WARN("minCTAs %d is larger than #channels upper limit %d", comm->minCTAs, MAXCHANNELS); + comm->minCTAs = MAXCHANNELS; + } + + if (comm->maxCTAs > MAXCHANNELS) { + WARN("maxCTAs %d is larger than #channels upper limit %d", comm->maxCTAs, MAXCHANNELS); + comm->maxCTAs = MAXCHANNELS; + } + + if (comm->minCTAs > comm->maxCTAs) { + WARN("minCTAs %d is larger than maxCTAs %d", comm->minCTAs, comm->maxCTAs); + ret = ncclInvalidArgument; + goto fail; + } + + envNetName = getenv("NCCL_NET"); + if (envNetName) + tmpNetName = envNetName; + if (tmpNetName != NULL) { + int netNameLen = strlen(tmpNetName) + 1; + comm->netName = (char*)malloc(netNameLen); + memcpy(comm->netName, tmpNetName, netNameLen); + } else { + comm->netName = NULL; + } + +exit: return ret; +fail: + goto exit; } static void ncclCommInitRankUndo(struct ncclAsyncJob* job_) { @@ -1533,6 +1651,7 @@ static ncclResult_t ncclCommInitRankDev(ncclComm_t* newcomm, int nranks, ncclUni CUDACHECKGOTO(cudaFree(NULL), res, fail); NCCLCHECKGOTO(PtrCheck(newcomm, "CommInitRank", "newcomm"), res, fail); + NCCLCHECKGOTO(PtrCheck(config, "CommInitRank", "config"), res, fail); if (nranks < 1 || myrank < 0 || myrank >= nranks) { WARN("Invalid rank requested : %d/%d", myrank, nranks); res = ncclInvalidArgument; @@ -1584,12 +1703,13 @@ ncclResult_t ncclCommInitRank(ncclComm_t* newcomm, int nranks, ncclUniqueId comm if (ncclParamDmaBufEnable()) rocmLibraryInit(); int cudaDev; + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; CUDACHECK(cudaGetDevice(&cudaDev)); NvtxParamsCommInitRank payload{myrank, nranks, cudaDev}; NVTX3_FUNC_WITH_PARAMS(CommInitRank, CommInitRankSchema, payload) - NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, NULL, -1)); + NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, &config, -1)); return ncclSuccess; } @@ -1599,12 +1719,13 @@ ncclResult_t ncclCommInitRankMulti(ncclComm_t* newcomm, int nranks, ncclUniqueId if (ncclParamDmaBufEnable()) rocmLibraryInit(); int cudaDev; + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; CUDACHECK(hipGetDevice(&cudaDev)); NvtxParamsCommInitRank payload{myrank, nranks, cudaDev}; NVTX3_FUNC_WITH_PARAMS(CommInitRank, CommInitRankSchema, payload) - NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, NULL, virtualId)); + NCCLCHECK(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, &config, virtualId)); return ncclSuccess; } @@ -1614,6 +1735,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { ncclResult_t ret = ncclSuccess; int totalnDev; int *gpuFlags = NULL; + ncclConfig_t config = NCCL_CONFIG_INITIALIZER; constexpr nvtxPayloadSchemaEntry_t CommInitAllSchema[] = { {0, NVTX_PAYLOAD_ENTRY_TYPE_INT, "No. of devices"} @@ -1657,7 +1779,7 @@ ncclResult_t ncclCommInitAll(ncclComm_t* comms, int ndev, const int* devlist) { NCCLCHECKGOTO(ncclGroupStart(), ret, fail); for (int i=0; i sizeof(ncclConfig_t) ? sizeof(ncclConfig_t) : realSize; - memcpy((void*)internalConfigPtr, (void*)config, realSize); - if (internalConfigPtr->magic != 0xcafebeef) { - WARN("ncclConfig_t argument not initialized via NCCL_CONFIG_INITIALIZER"); - ret = ncclInvalidArgument; - goto exit; - } - } - - /* check input config attributes */ - if (internalConfigPtr->blocking != 0 && internalConfigPtr->blocking != 1) { - WARN("Invalid config blocking attribute value %d", internalConfigPtr->blocking); - ret = ncclInvalidArgument; - goto exit; - } - - /* overwrite configuration from env variable. */ - blockingEnv = ncclParamCommBlocking(); - if (blockingEnv != 0 && blockingEnv != 1) { - WARN("Invalid NCCL_COMM_BLOCKING value %d", blockingEnv); - } - if (blockingEnv == 1) internalConfigPtr->blocking = blockingEnv; if (ncclParamDmaBufEnable()) (void) rocmLibraryInit(); - CUDACHECKGOTO(cudaGetDevice(&cudaDev), ret, exit); + CUDACHECKGOTO(cudaGetDevice(&cudaDev), ret, fail); + + if (config == NULL) + internalConfigPtr = &internalConfig; + else + internalConfigPtr = config; NCCLCHECKGOTO(ncclCommInitRankDev(newcomm, nranks, commId, myrank, cudaDev, internalConfigPtr, -1), ret, fail); exit: diff --git a/src/misc/cudawrap.cc b/src/misc/cudawrap.cc index e2c1a6ff22..4fe90237ce 100644 --- a/src/misc/cudawrap.cc +++ b/src/misc/cudawrap.cc @@ -23,11 +23,33 @@ DECLARE_CUDA_PFN(cuMemGetAddressRange, 3020); /* proxy.cc */ DECLARE_CUDA_PFN(cuCtxCreate, 3020); DECLARE_CUDA_PFN(cuCtxDestroy, 4000); +DECLARE_CUDA_PFN(cuCtxGetCurrent, 4000); DECLARE_CUDA_PFN(cuCtxSetCurrent, 4000); +DECLARE_CUDA_PFN(cuCtxGetDevice, 2000); +/* cuMem API support */ +DECLARE_CUDA_PFN(cuMemAddressReserve, 10020); +DECLARE_CUDA_PFN(cuMemAddressFree, 10020); +DECLARE_CUDA_PFN(cuMemCreate, 10020); +DECLARE_CUDA_PFN(cuMemGetAllocationGranularity, 10020); +DECLARE_CUDA_PFN(cuMemExportToShareableHandle, 10020); +DECLARE_CUDA_PFN(cuMemImportFromShareableHandle, 10020); +DECLARE_CUDA_PFN(cuMemMap, 10020); +DECLARE_CUDA_PFN(cuMemRelease, 10020); +DECLARE_CUDA_PFN(cuMemSetAccess, 10020); +DECLARE_CUDA_PFN(cuMemUnmap, 10020); #if CUDA_VERSION >= 11070 /* transport/collNet.cc/net.cc*/ DECLARE_CUDA_PFN(cuMemGetHandleForAddressRange, 11070); // DMA-BUF support #endif +#if CUDA_VERSION >= 12010 +/* NVSwitch Multicast support */ +DECLARE_CUDA_PFN(cuMulticastAddDevice, 12010); +DECLARE_CUDA_PFN(cuMulticastBindMem, 12010); +DECLARE_CUDA_PFN(cuMulticastBindAddr, 12010); +DECLARE_CUDA_PFN(cuMulticastCreate, 12010); +DECLARE_CUDA_PFN(cuMulticastGetGranularity, 12010); +DECLARE_CUDA_PFN(cuMulticastUnbind, 12010); +#endif #endif /* CUDA Driver functions loaded with dlsym() */ @@ -39,6 +61,7 @@ DECLARE_CUDA_PFN(cuGetProcAddress, 11030); static void *cudaLib; int ncclCudaDriverVersionCache = -1; +bool ncclCudaLaunchBlocking = false; #if CUDART_VERSION >= 11030 /* @@ -62,9 +85,33 @@ static ncclResult_t cudaPfnFuncLoader(void) { LOAD_SYM(cuMemGetAddressRange, 3020, 1); LOAD_SYM(cuCtxCreate, 3020, 1); LOAD_SYM(cuCtxDestroy, 4000, 1); + LOAD_SYM(cuCtxGetCurrent, 4000, 1); LOAD_SYM(cuCtxSetCurrent, 4000, 1); + LOAD_SYM(cuCtxGetDevice, 2000, 1); +/* cuMem API support */ +#if CUDA_VERSION >= 11030 + LOAD_SYM(cuMemAddressReserve, 10020, 1); + LOAD_SYM(cuMemAddressFree, 10020, 1); + LOAD_SYM(cuMemCreate, 10020, 1); + LOAD_SYM(cuMemGetAllocationGranularity, 10020, 1); + LOAD_SYM(cuMemExportToShareableHandle, 10020, 1); + LOAD_SYM(cuMemImportFromShareableHandle, 10020, 1); + LOAD_SYM(cuMemMap, 10020, 1); + LOAD_SYM(cuMemRelease, 10020, 1); + LOAD_SYM(cuMemSetAccess, 10020, 1); + LOAD_SYM(cuMemUnmap, 10020, 1); +#endif #if CUDA_VERSION >= 11070 LOAD_SYM(cuMemGetHandleForAddressRange, 11070, 1); // DMA-BUF support +#endif +#if CUDA_VERSION >= 12010 +/* NVSwitch Multicast support */ + LOAD_SYM(cuMulticastAddDevice, 12010, 1); + LOAD_SYM(cuMulticastBindMem, 12010, 1); + LOAD_SYM(cuMulticastBindAddr, 12010, 1); + LOAD_SYM(cuMulticastCreate, 12010, 1); + LOAD_SYM(cuMulticastGetGranularity, 12010, 1); + LOAD_SYM(cuMulticastUnbind, 12010, 1); #endif return ncclSuccess; } @@ -74,6 +121,11 @@ static pthread_once_t initOnceControl = PTHREAD_ONCE_INIT; static ncclResult_t initResult; static void initOnceFunc() { + do { + char* val = getenv("CUDA_LAUNCH_BLOCKING"); + ncclCudaLaunchBlocking = val!=nullptr && val[0]!=0 && !(val[0]=='0' && val[1]==0); + } while (0); + CUresult res; /* * Load CUDA driver library @@ -85,9 +137,10 @@ static void initOnceFunc() { else snprintf(path, 1024, "%s%s", ncclCudaPath, "libcuda.so"); + (void) dlerror(); // Clear any previous errors cudaLib = dlopen(path, RTLD_LAZY); if (cudaLib == NULL) { - WARN("Failed to find CUDA library (NCCL_CUDA_PATH='%s') : %s", ncclCudaPath ? ncclCudaPath : "", dlerror()); + WARN("Failed to find CUDA library %s (NCCL_CUDA_PATH='%s') : %s", path, ncclCudaPath ? ncclCudaPath : "", dlerror()); goto error; } diff --git a/src/misc/ipcsocket.cc b/src/misc/ipcsocket.cc new file mode 100644 index 0000000000..b2dee4852d --- /dev/null +++ b/src/misc/ipcsocket.cc @@ -0,0 +1,200 @@ +/* + * Copyright (c) 2016-2023, NVIDIA CORPORATION. All rights reserved. + * + * See COPYRIGHT for license information + */ + +#include "ipcsocket.h" +#include "utils.h" +#include +#include +#include + +// Enable Linux abstract socket naming +#define USE_ABSTRACT_SOCKET + +#define NCCL_IPC_SOCKNAME_STR "/tmp/nccl-socket-%d-%lx" + +/* + * Create a Unix Domain Socket + */ +ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash, volatile uint32_t* abortFlag) { + int fd = -1; + struct sockaddr_un cliaddr; + char temp[NCCL_IPC_SOCKNAME_LEN] = ""; + + if (handle == NULL) { + return ncclInternalError; + } + + handle->fd = -1; + handle->socketName[0] = '\0'; + if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) { + WARN("UDS: Socket creation error : %d", errno); + return ncclSystemError; + } + + bzero(&cliaddr, sizeof(cliaddr)); + cliaddr.sun_family = AF_UNIX; + + // Create unique name for the socket. + int len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); + if (len > (sizeof(cliaddr.sun_path) - 1)) { + WARN("UDS: Cannot bind provided name to socket. Name too large"); + return ncclInternalError; + } +#ifndef USE_ABSTRACT_SOCKET + unlink(temp); +#endif + + TRACE(NCCL_INIT, "UDS: Creating socket %s", temp); + + strncpy(cliaddr.sun_path, temp, len); +#ifdef USE_ABSTRACT_SOCKET + cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick +#endif + if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) { + WARN("UDS: Binding to socket %s failed : %d", temp, errno); + close(fd); + return ncclSystemError; + } + + handle->fd = fd; + strcpy(handle->socketName, temp); + + handle->abortFlag = abortFlag; + // Mark socket as non-blocking + if (handle->abortFlag) { + int flags; + EQCHECK(flags = fcntl(fd, F_GETFL), -1); + SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); + } + + return ncclSuccess; +} + +ncclResult_t ncclIpcSocketClose(ncclIpcSocket *handle) { + if (handle == NULL) { + return ncclInternalError; + } + if (handle->fd <= 0) { + return ncclSuccess; + } +#ifndef USE_ABSTRACT_SOCKET + if (handle->socketName[0] != '\0') { + unlink(handle->socketName); + } +#endif + close(handle->fd); + + return ncclSuccess; +} + +ncclResult_t ncclIpcSocketRecvFd(ncclIpcSocket *handle, int *recvFd) { + struct msghdr msg = {0, 0, 0, 0, 0, 0, 0}; + struct iovec iov[1]; + + // Union to guarantee alignment requirements for control array + union { + struct cmsghdr cm; + char control[CMSG_SPACE(sizeof(int))]; + } control_un; + + struct cmsghdr *cmptr; + char dummy_buffer[1]; + int ret; + + msg.msg_control = control_un.control; + msg.msg_controllen = sizeof(control_un.control); + + iov[0].iov_base = (void *)dummy_buffer; + iov[0].iov_len = sizeof(dummy_buffer); + + msg.msg_iov = iov; + msg.msg_iovlen = 1; + + while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { + WARN("UDS: Receiving data over socket failed : %d", errno); + return ncclSystemError; + } + if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; + } + + if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) { + if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) { + WARN("UDS: Receiving data over socket failed"); + return ncclSystemError; + } + + memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd)); + } else { + WARN("UDS: Receiving data over socket %s failed", handle->socketName); + return ncclSystemError; + } + + TRACE(NCCL_INIT|NCCL_P2P, "UDS: Got recvFd %d from socket %s", *recvFd, handle->socketName); + + return ncclSuccess; +} + +ncclResult_t ncclIpcSocketSendFd(ncclIpcSocket *handle, const int sendFd, int rank, uint64_t hash) { + struct msghdr msg; + struct iovec iov[1]; + char temp[NCCL_IPC_SOCKNAME_LEN]; + + union { + struct cmsghdr cm; + char control[CMSG_SPACE(sizeof(int))]; + } control_un; + + struct cmsghdr *cmptr; + struct sockaddr_un cliaddr; + + // Construct client address to send this shareable handle to + bzero(&cliaddr, sizeof(cliaddr)); + cliaddr.sun_family = AF_UNIX; + + int len = snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash); + if (len > (sizeof(cliaddr.sun_path) - 1)) { + WARN("UDS: Cannot connect to provided name for socket. Name too large"); + return ncclInternalError; + } + (void) strncpy(cliaddr.sun_path, temp, len); + + TRACE(NCCL_INIT, "UDS: Sending fd %d to UDS socket %s", sendFd, temp); + +#ifdef USE_ABSTRACT_SOCKET + cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick +#endif + + msg.msg_control = control_un.control; + msg.msg_controllen = sizeof(control_un.control); + + cmptr = CMSG_FIRSTHDR(&msg); + cmptr->cmsg_len = CMSG_LEN(sizeof(int)); + cmptr->cmsg_level = SOL_SOCKET; + cmptr->cmsg_type = SCM_RIGHTS; + + memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd)); + + msg.msg_name = (void *)&cliaddr; + msg.msg_namelen = sizeof(struct sockaddr_un); + + iov[0].iov_base = (void *)""; + iov[0].iov_len = 1; + msg.msg_iov = iov; + msg.msg_iovlen = 1; + msg.msg_flags = 0; + + ssize_t sendResult; + while ((sendResult = sendmsg(handle->fd, &msg, 0)) <= 0) { + if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) { + WARN("UDS: Sending data over socket %s failed : %d", temp, errno); + return ncclSystemError; + } + if (handle->abortFlag && *handle->abortFlag) return ncclInternalError; + } + + return ncclSuccess; +} diff --git a/src/misc/rocmwrap.cc b/src/misc/rocmwrap.cc index e2545ad2c9..1c70aae4f3 100644 --- a/src/misc/rocmwrap.cc +++ b/src/misc/rocmwrap.cc @@ -24,8 +24,14 @@ static enum { hsaUninitialized, hsaInitializing, hsaInitialized, hsaError } hsaS static void *hsaLib; static uint16_t version_major, version_minor; +bool ncclCudaLaunchBlocking = false; ncclResult_t rocmLibraryInit(void) { + do { + char* val = getenv("CUDA_LAUNCH_BLOCKING"); + ncclCudaLaunchBlocking = val!=nullptr && val[0]!=0 && !(val[0]=='0' && val[1]==0); + } while (0); + hsa_status_t res; if (hsaState == hsaInitialized) diff --git a/src/misc/socket.cc b/src/misc/socket.cc index dd10f312c7..6d934c4bd6 100644 --- a/src/misc/socket.cc +++ b/src/misc/socket.cc @@ -51,7 +51,7 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { int closed; - NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0, &closed)); + NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &closed)); if (closed) { char line[SOCKET_NAME_MAXLEN+1]; WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); @@ -827,23 +827,47 @@ ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size) { } // Receive or detect connection closed -ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed) { +ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed, bool blocking) { int offset = 0; if (sock == NULL) { WARN("ncclSocketTryRecv: pass NULL socket"); return ncclInvalidArgument; } *closed = 0; - while (offset < size) { + // Block until connection closes or nbytes received + if (blocking) { + while (offset < size) { + NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); + if (*closed) return ncclSuccess; + } + } else { NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); if (*closed) return ncclSuccess; + + // If any bytes were received, block waiting for the rest + if (offset > 0) { + while (offset < size) { + NCCLCHECK(socketProgressOpt(NCCL_SOCKET_RECV, sock, ptr, size, &offset, 0, closed)); + if (*closed) return ncclSuccess; + } + // No bytes were received, return ncclInProgress + } else { + return ncclInProgress; + } } return ncclSuccess; } ncclResult_t ncclSocketClose(struct ncclSocket* sock) { if (sock != NULL) { - if (sock->fd >= 0) close(sock->fd); + if (sock->fd >= 0) { + /* shutdown() is needed to send FIN packet to proxy thread; shutdown() is not affected + * by refcount of fd, but close() is. close() won't close a fd and send FIN packet if + * the fd is duplicated (e.g. fork()). So shutdown() guarantees the correct and graceful + * connection close here. */ + shutdown(sock->fd, SHUT_RDWR); + close(sock->fd); + } sock->state = ncclSocketStateClosed; sock->fd = -1; } diff --git a/src/nccl.h.in b/src/nccl.h.in index 385287bc09..6047b2f21d 100644 --- a/src/nccl.h.in +++ b/src/nccl.h.in @@ -30,7 +30,9 @@ extern "C" { #endif /*! @brief Opaque handle to communicator */ +#include typedef struct ncclComm* ncclComm_t; +#define NCCL_COMM_NULL NULL #define NCCL_UNIQUE_ID_BYTES 128 typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId; @@ -46,15 +48,22 @@ typedef enum { ncclSuccess = 0, ncclInProgress = 7, ncclNumResults = 8 } ncclResult_t; +#define NCCL_CONFIG_UNDEF_INT INT_MIN +#define NCCL_CONFIG_UNDEF_PTR NULL + /* Communicator configuration. Users can assign value to attributes to specify the * behavior of a communicator. */ -typedef struct ncclConfig_v21400 { +typedef struct ncclConfig_v21700 { /* attributes that users should never touch. */ size_t size; unsigned int magic; unsigned int version; /* attributes that users are able to customize. */ int blocking; + int cgaClusterSize; + int minCTAs; + int maxCTAs; + const char *netName; } ncclConfig_t; /* Config initializer must be assigned to initialize config structure when it is created. @@ -63,7 +72,11 @@ typedef struct ncclConfig_v21400 { sizeof(ncclConfig_t), /* size */ \ 0xcafebeef, /* magic */ \ NCCL_VERSION(NCCL_MAJOR, NCCL_MINOR, NCCL_PATCH), /* version */ \ - 1 /* blocking */ \ + NCCL_CONFIG_UNDEF_INT, /* blocking */ \ + NCCL_CONFIG_UNDEF_INT, /* cgaClusterSize */ \ + NCCL_CONFIG_UNDEF_INT, /* minCTAs */ \ + NCCL_CONFIG_UNDEF_INT, /* maxCTAs */ \ + NCCL_CONFIG_UNDEF_PTR /* netName */ \ } /*! @brief Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer. diff --git a/src/net.cc b/src/net.cc index 8da90a3638..d31a000202 100644 --- a/src/net.cc +++ b/src/net.cc @@ -183,14 +183,8 @@ ncclResult_t ncclNetPluginInit() { } void* netPluginLib = dlopen(ncclNetPluginName, RTLD_NOW | RTLD_LOCAL); if (netPluginLib == nullptr) { - // dlopen does not guarantee to set errno, but dlerror only gives us a - // string, so checking errno doesn't hurt to try to provide a better - // error message - if (errno == ENOENT) { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No plugin found (%s), using internal implementation", ncclNetPluginName); - } else { - INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Plugin load returned %d : %s.", errno, dlerror()); - } + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Plugin load (%s) returned %d : %s", ncclNetPluginName, errno, dlerror()); + INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No plugin found, using internal implementation"); return ncclSuccess; } @@ -271,9 +265,10 @@ static ncclResult_t collNetGetState(int i, enum ncclNetState* state) { ncclResult_t ncclNetInit(struct ncclComm* comm) { // Initialize main communication network - char* netName = getenv("NCCL_NET"); + char* netName; bool ok = false; + netName = comm->netName; for (int i=0; i<3; i++) { if (ncclNets[i] == nullptr) continue; enum ncclNetState state; @@ -335,9 +330,26 @@ ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) { ncclResult_t ret; ncclDebugNoWarn = NCCL_NET; NCCLCHECKGOTO(ncclNetListen(comm, dev, &handle, &lComm), ret, cleanup1); - NCCLWAITGOTO(ncclNetConnect(comm, dev, &handle, &sComm), sComm != NULL, comm->abortFlag, ret, cleanup2); - NCCLWAITGOTO(ncclNetAccept(comm, lComm, &rComm), rComm != NULL, comm->abortFlag, ret, cleanup3); - CUDACHECKGOTO(cudaMalloc(&gpuPtr, GPU_BUF_SIZE), ret, cleanup4); + + bool connected; + connected = false; + while (!connected) { + + // If we're aborting now, skip to cleanup + if (*comm->abortFlag) { + goto cleanup2; + } + + if (sComm == NULL) + NCCLCHECKGOTO(ncclNetConnect(comm, dev, &handle, &sComm), ret, cleanup2); + + if (rComm == NULL) + NCCLCHECKGOTO(ncclNetAccept(comm, lComm, &rComm), ret, cleanup2); + + connected = (rComm != NULL) && (sComm != NULL); + } + + CUDACHECKGOTO(cudaMalloc(&gpuPtr, GPU_BUF_SIZE), ret, cleanup2); if (ncclNetRegMr(comm, sComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle) == ncclSuccess) { NCCLCHECK(ncclNetDeregMr(comm, sComm, mHandle)); NCCLCHECK(ncclNetRegMr(comm, rComm, gpuPtr, GPU_BUF_SIZE, NCCL_PTR_CUDA, &mHandle)); @@ -346,11 +358,11 @@ ncclResult_t ncclGpuGdrSupport(struct ncclComm* comm, int* gdrSupport) { } ncclDebugNoWarn = 0; CUDACHECK(cudaFree(gpuPtr)); -cleanup4: - NCCLCHECK(ncclNetCloseRecv(comm, rComm)); -cleanup3: - NCCLCHECK(ncclNetCloseSend(comm, sComm)); cleanup2: + if (rComm != NULL) + NCCLCHECK(ncclNetCloseRecv(comm, rComm)); + if (sComm != NULL) + NCCLCHECK(ncclNetCloseSend(comm, sComm)); NCCLCHECK(ncclNetCloseListen(comm, lComm)); cleanup1: break; diff --git a/src/proxy.cc b/src/proxy.cc index 077a35e981..74551365cd 100644 --- a/src/proxy.cc +++ b/src/proxy.cc @@ -16,6 +16,7 @@ #include "timer.h" #include +#include static bool NeedProxy(int type, int pattern, int root, struct ncclRing* ring, int nranks) { if (pattern == ncclPatternRing || pattern == ncclPatternRingTwice) return true; @@ -37,6 +38,155 @@ struct ncclProxyPool { struct ncclProxyArgs elems[PROXYARGS_ALLOCATE_SIZE]; }; +static void expectedProxyResponseFree(struct ncclProxyState* state) { + struct ncclExpectedProxyResponse* elem = state->expectedResponses; + struct ncclExpectedProxyResponse* prev = NULL; + + while (elem) { + prev = elem; + elem = elem->next; + free(prev->respBuff); + free(prev); + } +} + +static ncclResult_t expectedProxyResponseStore(struct ncclProxyState* state, void* opId, void* respBuff, int respSize) { + struct ncclExpectedProxyResponse* elem = state->expectedResponses; + while (elem) { + if (elem->opId == opId) { + if (respSize != elem->respSize) { + WARN("Mismatched response size for opId=%p", opId); + return ncclInternalError; + } + + if (elem->done) { + WARN("Storing response for already completed opId=%p", opId); + return ncclInternalError; + } + + memcpy(elem->respBuff, respBuff, respSize); + elem->done = true; + return ncclSuccess; + } + elem = elem->next; + } + + WARN("Proxy response for opId=%p doesn't match any expected response", opId); + return ncclInternalError; +} + +static ncclResult_t expectedProxyResponseEnqueue(struct ncclProxyState* state, void* opId, int respSize, void* respData, int respDataSize) { + struct ncclExpectedProxyResponse* ex; + NCCLCHECK(ncclCalloc(&ex, 1)); + ex->opId = opId; + + // Pre-alloc response buffer + ex->respBuff = malloc(respSize); + ex->respSize = respSize; + ex->done = false; + if (respData) { + memcpy(ex->respBuff, respData, respDataSize); + ex->done = true; + } + + // Enqueue + struct ncclExpectedProxyResponse* list = state->expectedResponses; + if (list == NULL) { + state->expectedResponses = ex; + return ncclSuccess; + } + while (list->next) list = list->next; + list->next = ex; + return ncclSuccess; +} + +static ncclResult_t expectedProxyResponseDequeue(struct ncclProxyState* state, void* opId, void* respBuff, int* found) { + struct ncclExpectedProxyResponse* elem = state->expectedResponses; + struct ncclExpectedProxyResponse* prev = NULL; + *found = 0; + while (elem) { + if ((elem->opId == opId) && elem->done) { + if (prev == NULL) { + state->expectedResponses = elem->next; + } else { + prev->next = elem->next; + } + memcpy(respBuff, elem->respBuff, elem->respSize); + free(elem->respBuff); + free(elem); + *found = 1; + return ncclSuccess; + } + prev = elem; + elem = elem->next; + } + return ncclSuccess; +} + +static ncclResult_t expectedProxyResponseRemove(struct ncclProxyState* state, void* opId) { + struct ncclExpectedProxyResponse* elem = state->expectedResponses; + struct ncclExpectedProxyResponse* prev = NULL; + while (elem) { + if (elem->opId == opId) { + if (prev == NULL) { + state->expectedResponses = elem->next; + } else { + prev->next = elem->next; + } + free(elem->respBuff); + free(elem); + return ncclSuccess; + } + prev = elem; + elem = elem->next; + } + WARN("Couldn't find opId=%p", opId); + return ncclInternalError; +} + +static ncclResult_t asyncProxyOpEnqueue(struct ncclProxyLocalPeer* peer, ncclProxyAsyncOp* op) { + ncclProxyAsyncOp* list = peer->asyncOps; + if (list == NULL) { + peer->asyncOps = op; + return ncclSuccess; + } + while (list->next) list = list->next; + list->next = op; + return ncclSuccess; +} + +static ncclResult_t asyncProxyOpDequeue(struct ncclProxyLocalPeer* peer, ncclProxyAsyncOp* op) { + struct ncclProxyAsyncOp* elem = peer->asyncOps; + struct ncclProxyAsyncOp* prev = NULL; + while (elem) { + if (elem->opId == op->opId) { + if (prev == NULL) { + peer->asyncOps = elem->next; + } else { + prev->next = elem->next; + } + + if (elem->reqBuff) { + free(elem->reqBuff); + } + if (elem->respBuff) { + free(elem->respBuff); + } + free(elem); + + return ncclSuccess; + } + prev = elem; + elem = elem->next; + } + if (op) { + WARN("Attempting to dequeue nonexistent async opId=%p", op->opId); + } else { + WARN("Attempting to dequeue null operation"); + } + return ncclInternalError; +} + static ncclResult_t allocateArgs(struct ncclProxyProgressState* state, struct ncclProxyArgs** argsptr) { struct ncclProxyArgs* elem; if (state->pool == NULL) { @@ -86,7 +236,7 @@ ncclResult_t getOpIndex(struct ncclProxyArgs* op, struct ncclProxyProgressState* pool = pool->next; p++; } - WARN("Could not find pool of op %p\n", op); + WARN("Could not find pool of op %p", op); return ncclInternalError; } @@ -140,7 +290,7 @@ ncclResult_t dumpProxyState(struct ncclProxyProgressState* state) { nextOp->state |= OP_SEEN; printf("\n"); if (nextOp->next) { - WARN("Inactive op has next set!\n"); + WARN("Inactive op has next set!"); } nextOp = nextOp->nextPeer; } @@ -337,7 +487,7 @@ ncclResult_t ncclLocalOpAppend(struct ncclComm* comm, struct ncclProxyConnector* } } if (lastOp == -1) { - WARN("Unable to post incomplete proxy op chain %d..%d (opCount %ld)\n", proxyOps->nextOps, proxyOps->nextOpsEnd, lastOpCount); + WARN("Unable to post incomplete proxy op chain %d..%d (opCount %ld)", proxyOps->nextOps, proxyOps->nextOpsEnd, lastOpCount); return ncclInternalError; } // Cut chain at lastOp @@ -775,19 +925,6 @@ ncclResult_t ncclProxyProgressDestroy(struct ncclComm* comm) { return ncclSuccess; } -struct ncclProxyAsyncOp { - int type; - struct ncclProxyConnection* connection; - int reqSize, respSize; - char *reqBuff, *respBuff; -}; - -struct ncclProxyLocalPeer { - struct ncclSocket sock; - int localRank; - struct ncclProxyAsyncOp asyncOps; -}; - #define NCCL_PROXY_CONN_POOL_SIZE_POW2 7 #define NCCL_PROXY_CONN_POOL_SIZE (1<<(NCCL_PROXY_CONN_POOL_SIZE_POW2)) #define NCCL_PROXY_CONN_POOL_MASK ((NCCL_PROXY_CONN_POOL_SIZE)-1) @@ -795,7 +932,6 @@ struct ncclProxyConnectionPool { struct ncclProxyConnection** pools; int banks; int offset; - struct ncclProxyAsyncOp* ops; }; static ncclResult_t ncclProxyNewConnection(struct ncclProxyConnectionPool* pool, int* id) { @@ -893,26 +1029,137 @@ ncclResult_t ncclProxyConnect(struct ncclComm* comm, int transport, int send, in return ncclSuccess; } -const char* ncclProxyMsgTypeStr[] = { "Unknown", "Init", "SharedInit", "Setup", "Connect", "Start", "Close", "Abort", "Stop" }; -ncclResult_t ncclProxyCall(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) { +const char* ncclProxyMsgTypeStr[] = { "Unknown", "Init", "SharedInit", "Setup", "Connect", "Start", "Close", "Abort", "Stop", "ConvertFd" }; +ncclResult_t ncclProxyCallAsync(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, int respSize, void* opId) { struct ncclSocket* sock; ncclResult_t ret = ncclSuccess; + void* respData = NULL; + int respDataSize = 0; + struct ncclComm* comm = proxyConn->comm; + struct ncclIpcSocket ipcSock = { 0 }; - if (proxyConn->comm->proxyState.peerSocks == NULL) return ncclInternalError; - sock = proxyConn->comm->proxyState.peerSocks + proxyConn->localRank; + if (*comm->abortFlag != 0) { + WARN("ncclProxyCallAsync() - Saw abortFlag while waiting for proxyThread response"); + return ncclInternalError; + } + if (comm->proxyState.peerSocks == NULL) return ncclInternalError; + + sock = comm->proxyState.peerSocks + proxyConn->localRank; if (sock == NULL) return ncclInternalError; + + if (type == ncclProxyMsgConvertFd) { + // cuMem API support + // Create a UDS socket to receive the converted fd + NCCLCHECK(ncclIpcSocketInit(&ipcSock, comm->localRank, (uint64_t)proxyConn->connection, comm->abortFlag)); + } + NCCLCHECKGOTO(ncclSocketSend(sock, &type, sizeof(int)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &proxyConn->connection, sizeof(void*)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &reqSize, sizeof(int)), ret, error); NCCLCHECKGOTO(ncclSocketSend(sock, &respSize, sizeof(int)), ret, error); if (reqSize) NCCLCHECKGOTO(ncclSocketSend(sock, reqBuff, reqSize), ret, error); - if (respSize) NCCLCHECKGOTO(ncclSocketRecv(sock, respBuff, respSize), ret, error); + + if (type == ncclProxyMsgConvertFd) { + // cuMem API support + int recvFd = -1; + if (reqSize != sizeof(int) || respSize != sizeof(int)) return ncclInternalError; + // Receive converted fd over UDS + NCCLCHECK(ncclIpcSocketRecvFd(&ipcSock, &recvFd)); + TRACE(NCCL_NET, "UDS: ConvertFd rank %d returned %p %d", proxyConn->localRank, &recvFd, recvFd); + assert(recvFd != -1); + respData = &recvFd; + respDataSize = sizeof(recvFd); + NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + } else { + // Send opId to proxy + NCCLCHECKGOTO(ncclSocketSend(sock, &opId, sizeof(opId)), ret, error); + } + // Add proxyOp to expected response queue + NCCLCHECK(expectedProxyResponseEnqueue(&comm->proxyState, opId, respSize, respData, respDataSize)); + return ncclSuccess; error: - WARN("Proxy Call to rank %d failed (%s)", proxyConn->comm->localRankToRank[proxyConn->localRank], ncclProxyMsgTypeStr[type]); + NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + WARN("Proxy Call to rank %d failed (%s)", comm->localRankToRank[proxyConn->localRank], ncclProxyMsgTypeStr[type]); return ret; } +ncclResult_t ncclPollProxyResponse(struct ncclProxyConnector* proxyConn, void* respBuff, void* opId) { + struct ncclComm* comm = proxyConn->comm; + + // Receive the connection pointer from the Proxy + if (*comm->abortFlag) { + WARN("Comm %p is in abort state", comm); + return ncclInternalError; + } + if (comm->proxyState.peerSocks == NULL) return ncclInternalError; + + // Check response queue + int found = 0; + NCCLCHECK(expectedProxyResponseDequeue(&comm->proxyState, opId, respBuff, &found)); + if (found == 0) { + // Attempt to read in a new response header from the proxy thread + struct ncclSocket* sock = comm->proxyState.peerSocks + proxyConn->localRank; + + void* recvOpId; + int offset = 0; + if (ncclSuccess != ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset)) { + WARN("Socket recv failed while polling for opId=%p", opId); + return ncclInternalError; + } + + if (offset == 0) { + return ncclInProgress; + // If we've returned a partial response, block to receive the rest of it + } else if (offset < sizeof(recvOpId)) { + while (offset < sizeof(recvOpId)) + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &recvOpId, sizeof(recvOpId), &offset)); + } + + INFO(NCCL_PROXY, "ncclPollProxyResponse Recieved new opId=%p", recvOpId); + + // Now do a blocking recv of the response size + int respSize = 0; + NCCLCHECK(ncclSocketRecv(sock, &respSize, sizeof(respSize))); + + // If there's a respSize to recv + if (respSize > 0) { + NCCLCHECK(ncclSocketRecv(sock, respBuff, respSize)); + } + + if (recvOpId == opId) { + INFO(NCCL_PROXY, "recvOpId=%p matches expected opId=%p", recvOpId, opId); + NCCLCHECK(expectedProxyResponseRemove(&comm->proxyState, recvOpId)); + return ncclSuccess; + } else { + INFO(NCCL_PROXY, "Queing opId=%p", recvOpId); + // Store the result and mark response as completed + NCCLCHECK(expectedProxyResponseStore(&comm->proxyState, recvOpId, respBuff, respSize)); + return ncclInProgress; + } + } else { + INFO(NCCL_PROXY, "ncclPollProxyResponse Dequeued cached opId=%p", opId); + } + + return ncclSuccess; +} + +ncclResult_t ncclProxyCallBlocking(struct ncclProxyConnector* proxyConn, int type, void* reqBuff, int reqSize, void* respBuff, int respSize) { + // Alloc some memory to act as a handle + void* opId = malloc(1); + + NCCLCHECK(ncclProxyCallAsync(proxyConn, type, reqBuff, reqSize, respSize, opId)); + ncclResult_t res = ncclInProgress; + + while (res == ncclInProgress) { + res = ncclPollProxyResponse(proxyConn, respBuff, opId); + } + + free(opId); + + return res; +} + static ncclResult_t proxyProgressInit(struct ncclComm* comm) { struct ncclProxyProgressState* state = &comm->proxyState.progressState; if (state->opsPool == NULL) { @@ -1003,16 +1250,55 @@ static ncclResult_t proxyConnSharedInit(struct ncclProxyLocalPeer* peer, struct if (reqSize != sizeof(int) || respSize != 0) return ncclInternalError; int nChannels; NCCLCHECK(ncclSocketRecv(sock, &nChannels, sizeof(int))); + + // Store opId for completion response + void* opId; + NCCLCHECK(ncclSocketRecv(sock, &opId, sizeof(opId))); + INFO(NCCL_PROXY, "proxyConnSharedInit received opId=%p", opId); + if (connection->tcomm->proxySharedInit) NCCLCHECK(connection->tcomm->proxySharedInit(connection, comm, nChannels)); __atomic_store_n(&connection->state, connSharedInitialized, __ATOMIC_RELEASE); + + // Send the opId for referencing async operation + INFO(NCCL_PROXY, "proxyConnSharedInit::ncclSocketSend(opId=%p)", opId); + NCCLCHECK(ncclSocketSend(connection->sock, &opId, sizeof(opId))); + + // Send the response size + INFO(NCCL_PROXY, "proxyConnSharedInit::ncclSocketSend(op.respSize=%d)", respSize); + NCCLCHECK(ncclSocketSend(connection->sock, &respSize, sizeof(respSize))); + return ncclSuccess; } -static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclComm* comm, int* asyncOpCount) { +// cuMem API support +static ncclResult_t proxyConvertFd(struct ncclProxyLocalPeer* peer, struct ncclComm* comm) { + struct ncclSocket* sock = &peer->sock; + uint64_t connection; + NCCLCHECK(ncclSocketRecv(sock, &connection, sizeof(uint64_t))); + int reqSize, respSize; + NCCLCHECK(ncclSocketRecv(sock, &reqSize, sizeof(int))); + NCCLCHECK(ncclSocketRecv(sock, &respSize, sizeof(int))); + if (reqSize != sizeof(int) || respSize != sizeof(int)) return ncclInternalError; + + int fd; + struct ncclIpcSocket ipcSock = { 0 }; + NCCLCHECK(ncclSocketRecv(sock, &fd, sizeof(int))); + + INFO(NCCL_NET, "UDS: proxyConvertFd received fd %d peer %d connection %lx", fd, peer->localRank, connection); + // Send back the converted fd using UDS + NCCLCHECK(ncclIpcSocketInit(&ipcSock, comm->localRank, connection^1, comm->abortFlag)); + NCCLCHECK(ncclIpcSocketSendFd(&ipcSock, fd, peer->localRank, connection)); + NCCLCHECK(ncclIpcSocketClose(&ipcSock)); + return ncclSuccess; +} + +static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclComm* comm, int* asyncOpCount, struct ncclProxyLocalPeer* peer) { int done = 1; if (op->type == ncclProxyMsgSetup) { + INFO(NCCL_PROXY, "proxyProgressAsync::proxySetup() opId=%p", op->opId); NCCLCHECK(op->connection->tcomm->proxySetup(op->connection, comm, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done)); } else if (op->type == ncclProxyMsgConnect) { + INFO(NCCL_PROXY, "proxyProgressAsync::proxyConnect() opId=%p op.reqBuff=%p", op->opId, op->reqBuff); NCCLCHECK(op->connection->tcomm->proxyConnect(op->connection, comm, op->reqBuff, op->reqSize, op->respBuff, op->respSize, &done)); } else return ncclInternalError; if (done) { @@ -1020,31 +1306,38 @@ static ncclResult_t proxyProgressAsync(struct ncclProxyAsyncOp* op, struct ncclC __atomic_store_n(&op->connection->state, connSetupDone, __ATOMIC_RELEASE); else if (op->type == ncclProxyMsgConnect) __atomic_store_n(&op->connection->state, connConnected, __ATOMIC_RELEASE); - /* if setup or connect is done, we should not return any error at this point since + /* if setup or connect is done, we should not return any error at this point since * ncclSocketSend might already send the respBuff to the requester. If we still choose * to abort and close the connection, it can cause segfault if the requester is using * the respBuff. */ - if (op->respSize) ncclSocketSend(op->connection->sock, op->respBuff, op->respSize); - if (op->reqBuff) { - free(op->reqBuff); - op->reqBuff = NULL; + + // Send the opId for referencing async operation + NCCLCHECK(ncclSocketSend(op->connection->sock, &op->opId, sizeof(op->opId))); + + // Send the response size + NCCLCHECK(ncclSocketSend(op->connection->sock, &op->respSize, sizeof(op->respSize))); + + if (op->respSize) { + // Send the response + NCCLCHECK(ncclSocketSend(op->connection->sock, op->respBuff, op->respSize)); } - if (op->respBuff) { - free(op->respBuff); - op->respBuff = NULL; - } - op->type = 0; + + asyncProxyOpDequeue(peer, op); (*asyncOpCount)--; + return ncclSuccess; + } else if (*comm->abortFlag != 0) { return ncclInternalError; } - return ncclSuccess; + return ncclInProgress; } static ncclResult_t proxyConnSetupConnect(int type, struct ncclProxyLocalPeer* peer, struct ncclProxyConnectionPool* connectionPool, struct ncclComm* comm, int* asyncOpCount) { struct ncclSocket* sock = &peer->sock; - struct ncclProxyAsyncOp* asyncOp = &peer->asyncOps; + struct ncclProxyAsyncOp* asyncOp; + NCCLCHECK(ncclCalloc(&asyncOp, 1)); + asyncOp->type = type; NCCLCHECK(ncclSocketRecv(sock, &asyncOp->connection, sizeof(void*))); @@ -1054,9 +1347,16 @@ static ncclResult_t proxyConnSetupConnect(int type, struct ncclProxyLocalPeer* p NCCLCHECK(ncclCalloc(&asyncOp->reqBuff, asyncOp->reqSize)); NCCLCHECK(ncclSocketRecv(sock, asyncOp->reqBuff, asyncOp->reqSize)); } + + // Store opId for completion response + NCCLCHECK(ncclSocketRecv(sock, &asyncOp->opId, sizeof(asyncOp->opId))); + if (asyncOp->respSize) NCCLCHECK(ncclCalloc(&asyncOp->respBuff, asyncOp->respSize)); + + asyncProxyOpEnqueue(peer, asyncOp); + (*asyncOpCount)++; - NCCLCHECK(proxyProgressAsync(asyncOp, comm, asyncOpCount)); + NCCLCHECK(proxyProgressAsync(asyncOp, comm, asyncOpCount, peer)); return ncclSuccess; } @@ -1086,7 +1386,7 @@ void* ncclProxyService(void* _args) { pollfds[s].events = POLLHUP|POLLIN; } if (ncclSocketGetFd(comm->proxyState.listenSock, &pollfds[NCCL_MAX_LOCAL_RANKS].fd) != ncclSuccess) { - WARN("[Proxy Service] Get listenSock fd fails\n"); + WARN("[Proxy Service] Get listenSock fd fails"); return NULL; }; pollfds[NCCL_MAX_LOCAL_RANKS].events = POLLIN; @@ -1118,14 +1418,14 @@ void* ncclProxyService(void* _args) { } if (maxnpeers < s+1) maxnpeers = s+1; if (ncclSocketInit(&peers[s].sock) != ncclSuccess) { - WARN("[Service thread] Initialize peers[%d].sock fails\n", s); + WARN("[Service thread] Initialize peers[%d].sock fails", s); return NULL; } if (ncclSocketAccept(&peers[s].sock, comm->proxyState.listenSock) != ncclSuccess) { WARN("[Service thread] Accept failed %s", strerror(errno)); } else { if (ncclSocketGetFd(&peers[s].sock, &pollfds[s].fd) != ncclSuccess) { - WARN("[Service thread] Get peers[%d].sock fd fails\n", s); + WARN("[Service thread] Get peers[%d].sock fd fails", s); return NULL; } npeers++; @@ -1135,25 +1435,37 @@ void* ncclProxyService(void* _args) { for (int s=0; ssock; - struct ncclProxyAsyncOp* op = &peer->asyncOps; int closeConn = 0; int type = 0; ncclResult_t res = ncclSuccess; - if (pollfds[s].fd == -1) continue; - if (op->type != 0) { - res = proxyProgressAsync(op, comm, &asyncOpCount); + + // Progress all ops for this ncclProxyLocalPeer + ncclProxyAsyncOp* op = peer->asyncOps; + while (op != nullptr) { type = op->type; - if (res != ncclSuccess) closeConn = 1; - } else if (pollfds[s].revents & POLLIN) { + res = proxyProgressAsync(op, comm, &asyncOpCount, peer); + if (res == ncclSuccess || res == ncclInProgress) { + op = op->next; + } else { + // Res is a bad result + closeConn = 1; + WARN("[Service thread] Error encountered progressing operation=%s, res=%d, closing connection", ncclProxyMsgTypeStr[type], res); + break; + } + } + + // Check for additional ops coming in + if (pollfds[s].revents & POLLIN) { int closed; - if (ncclSocketTryRecv(sock, &type, sizeof(int), &closed) != ncclSuccess) { - WARN("[Service thread] Could not receive type from localRank %d", peer->localRank); + res = ncclSocketTryRecv(sock, &type, sizeof(int), &closed, false /*blocking*/); + if (res != ncclSuccess && res != ncclInProgress) { + WARN("[Service thread] Could not receive type from localRank %d, res=%u, closed=%d", peer->localRank, res, closed); closeConn = 1; } else if (closed) { INFO(NCCL_INIT|NCCL_NET, "[Service thread] Connection closed by localRank %d", peer->localRank); closeConn = 1; - } else { + } else if (res == ncclSuccess) { // We received something from the sock if (type == ncclProxyMsgStop) { stop = 1; closeConn = 1; @@ -1164,30 +1476,32 @@ void* ncclProxyService(void* _args) { } else if (type == ncclProxyMsgSharedInit) { res = proxyConnSharedInit(peers+s, &connectionPool, comm); } else if (type == ncclProxyMsgSetup || type == ncclProxyMsgConnect) { + INFO(NCCL_PROXY, "proxyConnSetupConnect for peer->localRank %d,", peer->localRank); res = proxyConnSetupConnect(type, peers+s, &connectionPool, comm, &asyncOpCount); + } else if (type == ncclProxyMsgConvertFd) { + res = proxyConvertFd(peers+s, comm); // cuMem API support } else { - WARN("[Service thread] Unknown command %d from localRank %d\n", type, peer->localRank); + WARN("[Service thread] Unknown command %d from localRank %d", type, peer->localRank); closeConn = 1; } + + INFO(NCCL_PROXY, "Received and initiated operation=%s res=%d", ncclProxyMsgTypeStr[type], res); } } else if (pollfds[s].revents & POLLHUP) { closeConn = 1; - } - if (res != ncclSuccess) { + } + if (res != ncclSuccess && res != ncclInProgress) { WARN("[Proxy Service %d] Failed to execute operation %s from rank %d, retcode %d", comm->rank, ncclProxyMsgTypeStr[type], comm->localRankToRank[peer->localRank], res); closeConn = 1; } + if (closeConn) { ncclSocketClose(sock); - if (op->reqBuff) { - free(op->reqBuff); - op->reqBuff = NULL; + + if (op != nullptr) { + asyncProxyOpDequeue(peer, op); + asyncOpCount--; } - if (op->respBuff) { - free(op->respBuff); - op->respBuff = NULL; - } - op->type = 0; pollfds[s].fd = -1; npeers--; } @@ -1255,6 +1569,7 @@ ncclResult_t ncclProxyDestroy(struct ncclComm* comm) { free(state->peerSocks); free(state->proxyOps); free(state->sharedDevMems); + expectedProxyResponseFree(state); } return ncclSuccess; } diff --git a/src/transport.cc b/src/transport.cc index 48ae3e02d7..7a2502bf3e 100644 --- a/src/transport.cc +++ b/src/transport.cc @@ -82,9 +82,12 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* // Stream used during transport setup; need for P2P pre-connect + CUDA Graph ncclResult_t ret = ncclSuccess; int highestType = TRANSPORT_P2P; // track highest transport type - struct ncclConnect data[2*MAXCHANNELS]; + struct ncclConnect** data = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Store intermediate send/recvData structs for connect + struct ncclConnect** recvData = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Points to entries inside data for given recv connection within a channel + struct ncclConnect** sendData = (ncclConnect**) malloc(sizeof(ncclConnect*) * comm->nRanks); // Points to entries inside data for given send connection within a channel NCCLCHECKGOTO(ncclStrongStreamAcquireUncaptured(&comm->hostStream), ret, fail); + // First time initialization for (int i=1; inRanks; i++) { int bootstrapTag = (i<<8) + (graph ? graph->id+1 : 0); int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; @@ -92,22 +95,28 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* uint64_t recvMask = comm->connectRecv[recvPeer+comm->nRanks*(connIndex == NCCL_CONN_IDX_P2P_NET ? NCCL_CONN_IDX_P2P_NET : 0)]; uint64_t sendMask = comm->connectSend[sendPeer+comm->nRanks*(connIndex == NCCL_CONN_IDX_P2P_NET ? NCCL_CONN_IDX_P2P_NET : 0)]; - struct ncclConnect* recvData = data; + // Data[i] contains all ncclConnect information for all send and receive connections with a given send and recv peer + // This data is packed in the array based on the number of sendChannels and recvChannels connected with these peers + // The first N entries contain recvData, connection information for recv connections + // The next M entries contain sendData, connection information for send connections + // It's not guaranteed that each entry of data has the same number of total or send/recv specific connections + data[i] = (ncclConnect*) malloc(sizeof(ncclConnect) * 2*MAXCHANNELS); + recvData[i] = data[i]; int sendChannels = 0, recvChannels = 0; int type; TIME_START(0); for (int c=0; c(comm, graph, recvData+recvChannels++, c, recvPeer, connIndex, &type), ret, fail); + NCCLCHECKGOTO(selectTransport<0>(comm, graph, recvData[i]+recvChannels++, c, recvPeer, connIndex, &type), ret, fail); if (type > highestType) highestType = type; } } TIME_STOP(0); TIME_START(1); - struct ncclConnect* sendData = recvData+recvChannels; + sendData[i] = recvData[i]+recvChannels; for (int c=0; c(comm, graph, sendData+sendChannels++, c, sendPeer, connIndex, &type), ret, fail); + NCCLCHECKGOTO(selectTransport<1>(comm, graph, sendData[i]+sendChannels++, c, sendPeer, connIndex, &type), ret, fail); if (type > highestType) highestType = type; } } @@ -116,54 +125,94 @@ ncclResult_t ncclTransportP2pSetup(struct ncclComm* comm, struct ncclTopoGraph* TIME_START(2); if (sendPeer == recvPeer) { if (recvChannels+sendChannels) { - NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); - NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data, sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); - sendData = data; - recvData = data+sendChannels; + NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, data[i], sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); + NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, data[i], sizeof(struct ncclConnect)*(recvChannels+sendChannels)), ret, fail); + sendData[i] = data[i]; + recvData[i] = data[i]+sendChannels; } } else { - if (recvChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels), ret, fail); - if (sendChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels), ret, fail); - if (sendChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData, sizeof(struct ncclConnect)*sendChannels), ret, fail); - if (recvChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData, sizeof(struct ncclConnect)*recvChannels), ret, fail); + if (recvChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, recvPeer, bootstrapTag, recvData[i], sizeof(struct ncclConnect)*recvChannels), ret, fail); + if (sendChannels) NCCLCHECKGOTO(bootstrapSend(comm->bootstrap, sendPeer, bootstrapTag, sendData[i], sizeof(struct ncclConnect)*sendChannels), ret, fail); + if (sendChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, sendPeer, bootstrapTag, sendData[i], sizeof(struct ncclConnect)*sendChannels), ret, fail); + if (recvChannels) NCCLCHECKGOTO(bootstrapRecv(comm->bootstrap, recvPeer, bootstrapTag, recvData[i], sizeof(struct ncclConnect)*recvChannels), ret, fail); } TIME_STOP(2); - - TIME_START(3); - for (int c=0; cchannels[c].peers[sendPeer].send + connIndex; - NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData++, 1, comm->rank, conn), ret, fail); - conn->connected = 1; - do { - struct ncclConnInfo connInfo; - CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); - CUDACHECKGOTO(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[sendPeer].send[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->hostStream.cudaStream), ret, fail); - CUDACHECKGOTO(hipStreamSynchronize(comm->hostStream.cudaStream), ret, fail); - if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break; - } while (true); - } - } - TIME_STOP(3); - TIME_START(4); - for (int c=0; cchannels[c].peers[recvPeer].recv + connIndex; - NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData++, 1, comm->rank, conn), ret, fail); - conn->connected = 1; - do { - struct ncclConnInfo connInfo; - CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); - CUDACHECKGOTO(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[recvPeer].recv[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->hostStream.cudaStream), ret, fail); - CUDACHECKGOTO(hipStreamSynchronize(comm->hostStream.cudaStream), ret, fail); - if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break; - } while (true); - } - } - TIME_STOP(4); - comm->connectRecv[recvPeer+comm->nRanks*(connIndex == NCCL_CONN_IDX_P2P_NET ? NCCL_CONN_IDX_P2P_NET : 0)] = comm->connectSend[sendPeer+comm->nRanks*(connIndex == NCCL_CONN_IDX_P2P_NET ? NCCL_CONN_IDX_P2P_NET : 0)] = 0UL; } + // Loop until all channels with all ranks have been connected + bool allChannelsConnected; + allChannelsConnected = false; + while (!allChannelsConnected) { + allChannelsConnected = true; + for (int i=1; inRanks; i++) { + int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; + int sendPeer = (comm->rank + i) % comm->nRanks; + uint64_t recvMask = comm->connectRecv[recvPeer+comm->nRanks*(connIndex == NCCL_CONN_IDX_P2P_NET ? NCCL_CONN_IDX_P2P_NET : 0)]; + uint64_t sendMask = comm->connectSend[sendPeer+comm->nRanks*(connIndex == NCCL_CONN_IDX_P2P_NET ? NCCL_CONN_IDX_P2P_NET : 0)]; + + int sendDataOffset = 0; + int recvDataOffset = 0; + for (int c=0; cchannels[c].peers[sendPeer].send + connIndex; + // This connector hasn't completed connection yet + if (conn->connected == 0) { + NCCLCHECKGOTO(conn->transportComm->connect(comm, sendData[i] + sendDataOffset++, 1, comm->rank, conn), ret, fail); + if (ret == ncclSuccess) { + conn->connected = 1; + do { + struct ncclConnInfo connInfo; + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[sendPeer].send[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); + CUDACHECKGOTO(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[sendPeer].send[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->hostStream.cudaStream), ret, fail); + CUDACHECKGOTO(hipStreamSynchronize(comm->hostStream.cudaStream), ret, fail); + if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break; + } while (true); + } else if (ret == ncclInProgress) { + allChannelsConnected = false; + } + } + } + TIME_STOP(3); + + // Start with recv channels + TIME_START(4); + if (recvMask & (1UL<channels[c].peers[recvPeer].recv + connIndex; + // This connector hasn't completed connection yet + if (conn->connected == 0) { + NCCLCHECKGOTO(conn->transportComm->connect(comm, recvData[i] + recvDataOffset++, 1, comm->rank, conn), ret, fail); + if (ret == ncclSuccess) { + conn->connected = 1; + do { + struct ncclConnInfo connInfo; + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[recvPeer].recv[connIndex], &conn->conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), ret, fail); + CUDACHECKGOTO(cudaMemcpyAsync(&connInfo, &comm->channels[c].devPeers[recvPeer].recv[connIndex], sizeof(struct ncclConnInfo), cudaMemcpyDeviceToHost, comm->hostStream.cudaStream), ret, fail); + CUDACHECKGOTO(hipStreamSynchronize(comm->hostStream.cudaStream), ret, fail); + if (memcmp(&connInfo, &conn->conn, sizeof(struct ncclConnInfo)) == 0) break; + } while (true); + } else if (ret == ncclInProgress) { + allChannelsConnected = false; + } + } + } + TIME_STOP(4); + } + } + } + + // Clear all connect masks and free each connectInfo array + for (int i=1; inRanks; i++) { + int recvPeer = (comm->rank - i + comm->nRanks) % comm->nRanks; + int sendPeer = (comm->rank + i) % comm->nRanks; + comm->connectRecv[recvPeer+comm->nRanks*(connIndex == NCCL_CONN_IDX_P2P_NET ? NCCL_CONN_IDX_P2P_NET : 0)] = comm->connectSend[sendPeer+comm->nRanks*(connIndex == NCCL_CONN_IDX_P2P_NET ? NCCL_CONN_IDX_P2P_NET : 0)] = 0UL; + free(data[i]); + } + + free(data); + free(sendData); + free(recvData); + if (highestTransportType != NULL) *highestTransportType = highestType; TIME_PRINT("P2P Setup/Connect"); exit: diff --git a/src/transport/coll_net.cc b/src/transport/coll_net.cc index b0c0449d29..7a5e012e02 100644 --- a/src/transport/coll_net.cc +++ b/src/transport/coll_net.cc @@ -155,13 +155,13 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int proxyRank; NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr)); - send->conn.direct |= req.useGdr ? NCCL_DIRECT_NIC : 0; + send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; // Determine whether we need to flush the GDR buffer on recv or not if (req.useGdr) NCCLCHECK(ncclTopoNeedFlush(comm->topo, myInfo->busId, &req.needFlush)); NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &send->proxyConn.localRank)); NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 1, myInfo->rank, &send->proxyConn)); - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); INFO(NCCL_INIT|NCCL_NET,"CollNet %02d/%1d : %d [send] via COLLNET/%s/%d%s comm %p nRanks %02d", channelId, connIndex, myInfo->rank, collNetName(comm), req.netDev, req.useGdr ? "/GDRDMA" : "", comm, comm->nRanks); @@ -174,12 +174,12 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph int proxyRank; NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, -1, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 0, &req.useGdr)); - recv->conn.direct |= req.useGdr ? NCCL_DIRECT_NIC : 0; + recv->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &recv->proxyConn.localRank)); NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_COLLNET, 0, myInfo->rank, &recv->proxyConn)); struct collNetRecvConnectInfo* info = (struct collNetRecvConnectInfo*) connectInfo; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), &info->collNetHandle, sizeof(collNetHandle_t))); + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), &info->collNetHandle, sizeof(collNetHandle_t))); INFO(NCCL_INIT|NCCL_NET,"CollNet %02d/%1d : %d [receive] via COLLNET/%s/%d%s comm %p nRanks %02d", channelId, connIndex, myInfo->rank, collNetName(comm), req.netDev, req.useGdr ? "/GDRDMA" : "", comm, comm->nRanks); @@ -224,7 +224,7 @@ static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* conne // We're on the same process as the proxy. We can pass a pointer to a struct. struct collNetConnectArgs args = { rank, nranks, connectInfos }; struct connectMap* map; - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); // If collnet connect failed, propagate error to fallback on regular p2p if (map == NULL) return ncclSystemError; @@ -250,7 +250,7 @@ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* conne // We're on the same process as the proxy. We can pass a pointer to a struct. struct collNetConnectArgs args = { rank, nranks, connectInfos }; struct connectMap* map; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgConnect, &args, sizeof(struct collNetConnectArgs), &map, sizeof(struct connectMap*))); // If collnet connect failed, propagate error to fallback on regular p2p if (map == NULL) return ncclSystemError; @@ -413,7 +413,7 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc } static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { - if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("sendProxyConnect: reqSize is %d != %ld\n", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } + if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("sendProxyConnect: reqSize is %d != %ld", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } struct collNetConnectArgs* args = (struct collNetConnectArgs*)reqBuff; struct collNetSendConnectInfo* info = (struct collNetSendConnectInfo*)(args->connectInfos+args->rank); @@ -429,7 +429,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(sharedConnect(comm, resources->netDev, args->connectInfos, args->nranks, args->rank, &resources->collNetComm)); // Collnet connect is allowed to fail. Gracefully handle that case by returning NULL to the caller. - if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld\n", respSize, sizeof(void*)); return ncclInternalError; } + if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; } if (resources->collNetComm == NULL) { *((struct connectMap**)respBuff) = NULL; return ncclSuccess; @@ -487,7 +487,7 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str } static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { - if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("recvProxyConnect: reqSize is %d != %ld\n", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } + if (reqSize != sizeof(struct collNetConnectArgs)) { WARN("recvProxyConnect: reqSize is %d != %ld", reqSize, sizeof(struct collNetConnectArgs)); return ncclInternalError; } struct collNetConnectArgs* args = (struct collNetConnectArgs*)reqBuff; struct recvResources* resources = (struct recvResources*)(connection->transportResources); @@ -497,7 +497,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(sharedConnect(comm, resources->netDev, args->connectInfos, args->nranks, args->rank, &resources->collNetComm)); // Collnet connect is allowed to fail. Gracefully handle that case by returning NULL to the caller. - if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld\n", respSize, sizeof(void*)); return ncclInternalError; } + if (respSize != sizeof(struct connectMap*)) { WARN("sendProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; } if (resources->collNetComm == NULL) { *((struct connectMap**)respBuff) = NULL; return ncclSuccess; @@ -556,7 +556,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str for (int p=0; pmhandles[p] = resources->mhandles[p]; - if (respSize != sizeof(struct connectMap*)) { WARN("recvProxyConnect: respSize is %d != %ld\n", respSize, sizeof(void*)); return ncclInternalError; } + if (respSize != sizeof(struct connectMap*)) { WARN("recvProxyConnect: respSize is %d != %ld", respSize, sizeof(void*)); return ncclInternalError; } *((struct connectMap**)respBuff) = &resources->map; return ncclSuccess; } diff --git a/src/transport/net.cc b/src/transport/net.cc index 0ddf6438c0..ec7a2ffaba 100644 --- a/src/transport/net.cc +++ b/src/transport/net.cc @@ -189,7 +189,7 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph if (connIndex == NCCL_CONN_IDX_P2P_NET) NCCLCHECK(ncclTopoGetIntraNetDev(comm->topo, myInfo->rank, graph, channelId, 1, &req.netDev)); if (req.netDev < 0) NCCLCHECK(ncclTopoGetNetDev(comm, myInfo->rank, graph, channelId, peerInfo->rank, &req.netDev, &proxyRank)); NCCLCHECK(ncclTopoCheckGdr(comm->topo, myInfo->busId, req.netDev, 1, &req.useGdr)); - send->conn.direct |= req.useGdr ? NCCL_DIRECT_NIC : 0; + send->conn.flags |= req.useGdr ? NCCL_DIRECT_NIC : 0; if (req.useGdr && comm->topo->nodes[GPU].nodes[0].gpu.gcn != 910) { CUDACHECK(hipDeviceGetAttribute((int*)&req.curr_hdp_reg, hipDeviceAttributeHdpMemFlushCntl, myInfo->cudaDev)); send->conn.curr_hdp_reg = req.curr_hdp_reg; @@ -199,7 +199,7 @@ static ncclResult_t sendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph req.rank = myInfo->rank; NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &req.localRank)); req.remoteRank = peerInfo->rank; - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), NULL, 0)); if (proxyRank == myInfo->rank) { INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [send] via NET/%s/%d%s%s comm %p nRanks %02d", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, ncclNetName(comm), req.netDev, @@ -241,8 +241,7 @@ static ncclResult_t recvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph req.rank = myInfo->rank; NCCLCHECK(ncclTopoGetLocalRank(comm->topo, myInfo->rank, &req.localRank)); req.remoteRank = peerInfo->rank; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), connectInfo, sizeof(ncclNetHandle_t))); - + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &req, sizeof(req), connectInfo, sizeof(ncclNetHandle_t))); INFO(NCCL_INIT|NCCL_NET,"Channel %02d/%d : %d[%lx] -> %d[%lx] [receive] via NET/%s/%d%s%s comm %p nRanks %02d", channelId, connIndex, peerInfo->rank, peerInfo->busId, myInfo->rank, myInfo->busId, ncclNetName(comm), req.netDev, req.useGdr ? "/GDRDMA" : "", req.shared ? "/Shared" : "", comm, comm->nRanks); return ncclSuccess; @@ -287,11 +286,28 @@ static ncclResult_t netDumpMap(struct connectMap* map) { } static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* send) { - // Setup device pointers - struct connectMap* map; - NCCLCHECK(ncclCalloc(&map, 1)); - send->transportResources = map; - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(ncclNetHandle_t), map, sizeof(struct connectMap))); + struct connectMap* map = (connectMap*) send->transportResources; + + void* opId; + + // map isn't allocated thus this op hasn't been submitted yet + if (!map) { + // Setup device pointers + NCCLCHECK(ncclCalloc(&map, 1)); + send->transportResources = map; + opId = send; + INFO(NCCL_PROXY, "sendConnect ncclProxyCallAsync opId=%p", opId); + NCCLCHECK(ncclProxyCallAsync(&send->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(ncclNetHandle_t), sizeof(struct connectMap), opId)); + } else { + opId = send; + } + + ncclResult_t ret; + NCCLCHECK(ret = ncclPollProxyResponse(&send->proxyConn, map, opId)); + if (ret == ncclInProgress) { + return ret; + } + INFO(NCCL_PROXY, "sendConnect ncclPollProxyResponse opId=%p", opId); if (map->sameProcess) { if (map->cudaDev != comm->cudaDev) { @@ -338,10 +354,26 @@ static ncclResult_t sendConnect(struct ncclComm* comm, struct ncclConnect* conne /* Connect to this peer */ static ncclResult_t recvConnect(struct ncclComm* comm, struct ncclConnect* connectInfo, int nranks, int rank, struct ncclConnector* recv) { - struct connectMap* map; - NCCLCHECK(ncclCalloc(&map, 1)); - recv->transportResources = map; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(int), map, sizeof(struct connectMap))); + struct connectMap* map = (connectMap*) recv->transportResources; + void* opId; + if (!map) { + NCCLCHECK(ncclCalloc(&map, 1)); + recv->transportResources = map; + // Use recv connector as unique identifier + opId = recv; + INFO(NCCL_PROXY, "recvConnect ncclProxyCallAsync opId=%p &recv->proxyConn=%p connectInfo=%p", + opId, &recv->proxyConn, connectInfo); + NCCLCHECK(ncclProxyCallAsync(&recv->proxyConn, ncclProxyMsgConnect, connectInfo, sizeof(int), sizeof(struct connectMap), opId)); + } else { + opId = recv; + } + + ncclResult_t ret; + NCCLCHECK(ret = ncclPollProxyResponse(&recv->proxyConn, map, opId)); + if (ret == ncclInProgress) { + return ret; + } + INFO(NCCL_PROXY, "recvConnect ncclPollProxyResponse opId=%p", opId); //NCCLCHECK(netDumpMap(map)); struct ncclSendMem *sendMem = (struct ncclSendMem*) NCCL_NET_MAP_GET_POINTER(map, gpu, sendMem); @@ -514,12 +546,14 @@ static ncclResult_t recvProxySetup(struct ncclProxyConnection* connection, struc if (respSize != sizeof(ncclNetHandle_t)) return ncclInternalError; NCCLCHECK(ncclNetListen(comm, req->netDev, respBuff, &resources->netListenComm)); *done = 1; + return ncclSuccess; } static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, struct ncclComm* comm, void* reqBuff, int reqSize, void* respBuff, int respSize, int* done) { struct sendResources* resources = (struct sendResources*)(connection->transportResources); if (reqSize != sizeof(ncclNetHandle_t)) return ncclInternalError; + ncclResult_t ret = ncclSuccess; if (resources->shared) { // Shared buffers @@ -539,21 +573,22 @@ static ncclResult_t sendProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclCalloc(progressState->netComms+resources->netDev, comm->nRanks)); } struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev]+resources->remoteRank; - if (comms->sendComm[resources->channelId] == NULL) NCCLCHECK(ncclNetConnect(comm, resources->netDev, reqBuff, comms->sendComm+resources->channelId)); + if (comms->sendComm[resources->channelId] == NULL) ret = ncclNetConnect(comm, resources->netDev, reqBuff, comms->sendComm+resources->channelId); resources->netSendComm = comms->sendComm[resources->channelId]; if (comms->sendComm[resources->channelId]) comms->sendRefCount[resources->channelId]++; } else { - NCCLCHECK(ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm)); + ret = ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm); } } else { // Connect to remote peer - NCCLCHECK(ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm)); + ret = ncclNetConnect(comm, resources->netDev, reqBuff, &resources->netSendComm); connection->proxyAppendPtr = &connection->proxyAppend; } + NCCLCHECK(ret); if (resources->netSendComm == NULL) { *done = 0; - return ncclSuccess; + return ncclInProgress; } *done = 1; @@ -666,6 +701,7 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str if (reqSize != sizeof(int)) return ncclInternalError; struct recvResources* resources = (struct recvResources*)(connection->transportResources); resources->proxyRank = *(int*)reqBuff; + ncclResult_t ret = ncclSuccess; // Finish connection establishment from remote peer if (resources->shared) { @@ -686,23 +722,25 @@ static ncclResult_t recvProxyConnect(struct ncclProxyConnection* connection, str NCCLCHECK(ncclCalloc(progressState->netComms+resources->netDev, comm->nRanks)); } struct ncclSharedNetComms* comms = progressState->netComms[resources->netDev]+resources->proxyRank; - if (comms->recvComm[resources->channelId] == NULL) NCCLCHECK(ncclNetAccept(comm, resources->netListenComm, comms->recvComm+resources->channelId)); + if (comms->recvComm[resources->channelId] == NULL) ret = ncclNetAccept(comm, resources->netListenComm, comms->recvComm+resources->channelId); resources->netRecvComm = comms->recvComm[resources->channelId]; if (comms->recvComm[resources->channelId]) comms->recvRefCount[resources->channelId]++; } else { - NCCLCHECK(ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm)); + ret = ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm); } } else { // Connect to remote peer - NCCLCHECK(ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm)); + ret = ncclNetAccept(comm, resources->netListenComm, &resources->netRecvComm); connection->proxyAppendPtr = &connection->proxyAppend; } + NCCLCHECK(ret); if (resources->netRecvComm == NULL) { *done = 0; - return ncclSuccess; + return ncclInProgress; } *done = 1; + NCCLCHECK(ncclNetCloseListen(comm, resources->netListenComm)); // Create structures diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index b900188c28..a01f391133 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -385,7 +385,9 @@ enum ncclIbCommState { ncclIbCommStateAccept = 3, ncclIbCommStateSend = 4, ncclIbCommStateRecv = 5, - ncclIbCommStateConnected = 6, + ncclIbCommStateConnecting = 6, + ncclIbCommStateConnected = 7, + ncclIbCommStatePendingReady = 8, }; struct ncclIbCommStage { @@ -633,8 +635,10 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm) { int ready; *sendComm = NULL; - if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; - if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; + if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStateConnecting) goto ib_connect; + if (stage->state == ncclIbCommStateConnected) goto ib_send_ready; if (stage->state != ncclIbCommStateStart) { WARN("Error: trying to connect already connected sendComm"); return ncclInternalError; @@ -698,11 +702,37 @@ ib_connect_check: ib_send: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->sock, stage->buffer, sizeof(qpInfo), &stage->offset)); - if (stage->offset != sizeof(qpInfo)) - return ncclSuccess; + if (stage->offset != sizeof(qpInfo)) return ncclSuccess; + + stage->state = ncclIbCommStateConnecting; + stage->offset = 0; + // Clear the staging buffer for re-use + memset(stage->buffer, 0, sizeof(qpInfo)); + +ib_connect: + struct ncclIbQpInfo remQpInfo; + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, stage->buffer, sizeof(ncclIbQpInfo), &stage->offset)); + if (stage->offset != sizeof(remQpInfo)) return ncclSuccess; + + memcpy(&remQpInfo, stage->buffer, sizeof(ncclIbQpInfo)); + + for (int q=0; qnqps; q++) { + struct ibv_qp* qp = comm->qps[q]; + NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); + NCCLCHECK(ncclIbRtsQp(qp)); + } + + comm->ready = 1; + stage->state = ncclIbCommStateConnected; + stage->offset = 0; + +ib_send_ready: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->sock, &comm->ready, sizeof(int), &stage->offset)); + if (stage->offset != sizeof(int)) return ncclSuccess; free(stage->buffer); - stage->state = ncclIbCommStateConnected; + stage->state = ncclIbCommStateStart; + *sendComm = comm; return ncclSuccess; } @@ -719,8 +749,9 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm) { if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; if (stage->state == ncclIbCommStateRecv) goto ib_recv; if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStatePendingReady) goto ib_recv_ready; if (stage->state != ncclIbCommStateStart) { - WARN("Listencomm in unknown state %d\n", stage->state); + WARN("Listencomm in unknown state %d", stage->state); return ncclInternalError; } @@ -738,10 +769,10 @@ ib_accept_check: stage->state = ncclIbCommStateRecv; stage->offset = 0; NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remQpInfo))); + ib_recv: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->sock, stage->buffer, sizeof(remQpInfo), &stage->offset)); - if (stage->offset != sizeof(remQpInfo)) - return ncclSuccess; + if (stage->offset != sizeof(remQpInfo)) return ncclSuccess; /* copy back the received info */ memcpy(&remQpInfo, stage->buffer, sizeof(struct ncclIbQpInfo)); @@ -814,10 +845,18 @@ ib_recv: if (stage->buffer) free(stage->buffer); NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(struct ncclIbQpInfo))); memcpy(stage->buffer, &qpInfo, sizeof(struct ncclIbQpInfo)); + ib_send: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->sock, stage->buffer, sizeof(struct ncclIbQpInfo), &stage->offset)); if (stage->offset < sizeof(struct ncclIbQpInfo)) return ncclSuccess; + stage->offset = 0; + stage->state = ncclIbCommStatePendingReady; + +ib_recv_ready: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->sock, &rComm->ready, sizeof(int), &stage->offset)); + if (stage->offset != sizeof(int)) return ncclSuccess; + free(stage->buffer); *recvComm = rComm; @@ -849,36 +888,6 @@ ncclResult_t ncclIbFreeRequest(struct ncclIbRequest* r) { return ncclSuccess; } -ncclResult_t ncclSendCheck(struct ncclIbSendComm* comm) { - struct ncclIbQpInfo remQpInfo; - - // Do not block on this receive, return if not ready. - int bytes = 0; - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &remQpInfo, sizeof(remQpInfo), &bytes)); - if (bytes == 0) return ncclSuccess; // Try again later - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_RECV, &comm->sock, &remQpInfo, sizeof(remQpInfo), &bytes)); - - for (int q=0; qnqps; q++) { - struct ibv_qp* qp = comm->qps[q]; - NCCLCHECK(ncclIbRtrQp(qp, remQpInfo.qpn[q], &remQpInfo)); - NCCLCHECK(ncclIbRtsQp(qp)); - } - comm->ready = 1; - // Block until this is done. It *should* not block indefinitely. - NCCLCHECK(ncclSocketSend(&comm->sock, &comm->ready, sizeof(int))); - - return ncclSuccess; -} - -ncclResult_t ncclRecvCheck(struct ncclIbRecvComm* comm) { - // Do not block on this receive, return if not ready. - int bytes = 0; - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &comm->ready, sizeof(int), &bytes)); - if (bytes == 0) return ncclSuccess; // Try again later - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_RECV, &comm->sock, &comm->ready, sizeof(int), &bytes)); - return ncclSuccess; -} - ncclResult_t ncclIbTest(void* request, int* done, int* size); /* DMA-BUF support */ @@ -1054,7 +1063,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; - if (comm->ready == 0) NCCLCHECK(ncclSendCheck(comm)); + if (comm->ready == 0) { WARN("NET/IB: ncclIbIsend() called when comm->ready == 0"); return ncclInternalError; } if (comm->ready == 0) { *request = NULL; return ncclSuccess; } struct ibv_mr* mr = (struct ibv_mr*)mhandle; @@ -1187,7 +1196,7 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; - if (comm->ready == 0) NCCLCHECK(ncclRecvCheck(comm)); + if (comm->ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->ready == 0"); return ncclInternalError; } if (comm->ready == 0) { *request = NULL; return ncclSuccess; } if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; diff --git a/src/transport/nvls.cc b/src/transport/nvls.cc new file mode 100644 index 0000000000..336877ce2b --- /dev/null +++ b/src/transport/nvls.cc @@ -0,0 +1,373 @@ +/************************************************************************* + * Copyright (c) 2016-2023, NVIDIA CORPORATION. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ + +// Implementation of the NVLink SHARP (NVLS) transport + +#include "comm.h" +#include "graph.h" +#include "utils.h" +#include "proxy.h" + +#if CUDART_VERSION >= 12010 + +// Currently we only support POSIX_FILE_DESCRIPTOR handle exchange +#define USE_POSIX_FD 1 + +#if USE_POSIX_FD +#define NVLS_CU_MEM_HANDLE_TYPE CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR +#else +#define NVLS_CU_MEM_HANDLE_TYPE CU_MEM_HANDLE_TYPE_NONE +#endif + +ncclResult_t nvlsCanConnect(int* ret, struct ncclTopoSystem* topo, struct ncclTopoGraph* graph, struct ncclPeerInfo* info1, struct ncclPeerInfo* info2) { + // This transport cannot be used for p2p + *ret = 0; + return ncclSuccess; +} + +ncclResult_t nvlsSendFree(struct ncclConnector* send) { + return ncclSuccess; +} + +ncclResult_t nvlsRecvFree(struct ncclConnector* recv) { + return ncclSuccess; +} + +struct ncclTransport nvlsTransport = { + "NVLS", + nvlsCanConnect, + { NULL, NULL, nvlsSendFree, NULL, NULL, NULL, NULL, NULL }, + { NULL, NULL, nvlsRecvFree, NULL, NULL, NULL, NULL, NULL } +}; + +#define NVLS_HANDLE_SIZE 64 + +struct nvlsResources { + CUmulticastObjectProp properties; + CUmemAccessDesc accessDesc; + int dev; + size_t size; + size_t granularity; + CUmemGenericAllocationHandle mcHandle; // Multicast handle for NVLS buffer + char* mcBuff; // Multicast NVLS buffer address + CUmemGenericAllocationHandle ucHandle; // Unicast Handle for NVLS buffer + char* ucBuff; // Unicast NVLS buffer address +}; + + +ncclResult_t nvlsGetProperties(struct ncclComm *comm, struct nvlsResources* resources, int dev, int nranks, size_t size) { + CUmulticastObjectProp* prop = &resources->properties; + memset(prop, 0, sizeof(*prop)); + prop->size = size; + prop->numDevices = nranks; + prop->handleTypes = NVLS_CU_MEM_HANDLE_TYPE; + prop->flags = 0; + + // Could be changed to CU_MULTICAST_GRANULARITY_MINIMUM when 3418538 resolved + CUCHECK(cuMulticastGetGranularity(&resources->granularity, prop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); + + ALIGN_SIZE(size, resources->granularity); + prop->size = resources->size = size; + + memset(&resources->accessDesc, 0, sizeof(resources->accessDesc)); + resources->accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + resources->accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + resources->accessDesc.location.id = dev; + resources->dev = dev; + + return ncclSuccess; +} + +ncclResult_t nvlsGroupCreate(struct ncclComm *comm, struct nvlsResources* resources, int rank, unsigned int nranks, char* shareableHandle) { + size_t size = resources->size; + + // Create a Multicast group + CUmulticastObjectProp* prop = &resources->properties; + + INFO(NCCL_NVLS, "NVLS Creating Multicast group nranks %d size %zi on rank %d", nranks, size, rank); + CUCHECK(cuMulticastCreate(&resources->mcHandle, prop)); + + if (NVLS_CU_MEM_HANDLE_TYPE != CU_MEM_HANDLE_TYPE_NONE) { + // Get a handle to pass to other ranks + CUCHECK(cuMemExportToShareableHandle(shareableHandle, resources->mcHandle, NVLS_CU_MEM_HANDLE_TYPE, 0)); + } + else { + memcpy(shareableHandle, &resources->mcHandle, sizeof(resources->mcHandle)); + } + + INFO(NCCL_NVLS, "NVLS Created Multicast group %llx nranks %d size %zi on rank %d", resources->mcHandle, nranks, size, rank); + + return ncclSuccess; +} + +ncclResult_t nvlsGroupAddDevice(struct ncclComm *comm, struct nvlsResources* resources) { + INFO(NCCL_NVLS, "NVLS group %llx adding dev %d", resources->mcHandle, resources->dev); + CUCHECK(cuMulticastAddDevice(resources->mcHandle, resources->dev)); + return ncclSuccess; +} + +ncclResult_t nvlsGroupUnbind(struct ncclComm *comm, struct nvlsResources* resources) { + int dev = resources->dev; + size_t size = resources->size; + INFO(NCCL_NVLS, "NVLS Unbind MC handle %llx size %zi dev %d", resources->mcHandle, size, dev); + + // Unbind physical memory from group for the given device + CUCHECK(cuMulticastUnbind(resources->mcHandle, dev, 0/*mcOffset*/, size)); + + return ncclSuccess; +} + +ncclResult_t nvlsGroupConnect(struct ncclComm *comm, struct nvlsResources* resources, int rank, char* shareableHandle) { + CUmemAllocationHandleType type = NVLS_CU_MEM_HANDLE_TYPE; + + INFO(NCCL_NVLS, "NVLS importing shareableHandle %p from rank %d", shareableHandle, rank); + + // Import and map the remote memory descriptor to the local GPU + if (type == CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR) { + // cuMem UDS support + int fd = *(int *)shareableHandle; + TRACE(NCCL_NVLS, "NVLS rank %d Importing shareable handle from rank %d fd %d", comm->localRank, rank, fd); + struct ncclProxyConnector proxyConn; + NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, rank, &proxyConn)); + TRACE(NCCL_NVLS, "NVLS rank %d request conversion of fd %d from rank %d", comm->localRank, fd, rank); + NCCLCHECK(ncclProxyCallBlocking(&proxyConn, ncclProxyMsgConvertFd, shareableHandle, sizeof(int), &fd, sizeof(int))); + TRACE(NCCL_NVLS, "NVLS rank %d received converted fd %d from rank %d", comm->localRank, fd, rank); + CUCHECK(cuMemImportFromShareableHandle(&resources->mcHandle, (void *)(uintptr_t)fd, type)); + } else { + if (NVLS_CU_MEM_HANDLE_TYPE != CU_MEM_HANDLE_TYPE_NONE) { + CUCHECK(cuMemImportFromShareableHandle(&resources->mcHandle, (void *)shareableHandle, type)); + } else { + memcpy(&resources->mcHandle, shareableHandle, sizeof(resources->mcHandle)); + } + } + return ncclSuccess; +} + +ncclResult_t nvlsGroupBindMem(struct ncclComm *comm, struct nvlsResources* resources) { + size_t size = resources->size; + size_t granularity; + CUdeviceptr ptr = 0; + CUmemAllocationProp prop; + + memset(&prop, 0, sizeof(prop)); + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = resources->dev; + prop.requestedHandleTypes = NVLS_CU_MEM_HANDLE_TYPE; + CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); + + // Map a VA for UC memory + CUCHECK(cuMemAddressReserve(&ptr, size, granularity, 0U, 0)); + + // Alloc local physical mem for this NVLS group + CUCHECK(cuMemCreate(&resources->ucHandle, size, &prop, 0)); + CUCHECK(cuMemMap(ptr, size, 0, resources->ucHandle, 0)); + CUCHECK(cuMemSetAccess(ptr, size, &resources->accessDesc, 1)); + CUDACHECK(cudaMemset((void*)ptr, 0, size)); + resources->ucBuff = (char*)ptr; + INFO(NCCL_NVLS, "NVLS Mapped UC at %p size %zi", resources->ucBuff, size); + + // Bind physical memory to the Multicast group + // NB: It will block until all ranks have been added to the Group + INFO(NCCL_NVLS, "NVLS Bind mem %p UC handle 0x%llx MC handle 0x%llx size %zi", (void*)ptr, resources->ucHandle, resources->mcHandle, size); + CUCHECK(cuMulticastBindMem(resources->mcHandle, 0/*mcOffset*/, resources->ucHandle, 0/*memOffset*/, size, 0/*flags*/)); + + return ncclSuccess; +} + +ncclResult_t nvlsGroupMapMem(struct ncclComm *comm, struct nvlsResources* resources) { + size_t size = resources->size; + CUdeviceptr ptr = 0; + + // Create a VA for the NVLS + CUCHECK(cuMemAddressReserve(&ptr, size, resources->granularity, 0U, 0)); + // Map the VA locally + CUCHECK(cuMemMap(ptr, size, 0, resources->mcHandle, 0)); + resources->mcBuff = (char*)ptr; + INFO(NCCL_NVLS, "NVLS Mapped MC buffer at %p size %zi", resources->mcBuff, size); + + // Having completed the BindMem we can now call SetAccess + // NB: It will block until all ranks have bound to the Group + CUCHECK(cuMemSetAccess((CUdeviceptr)resources->mcBuff, size, &resources->accessDesc, 1)); + + return ncclSuccess; +} + +ncclResult_t nvlsGroupUnmapMem(struct ncclComm *comm, struct nvlsResources* resources) { + size_t size; + CUdeviceptr ptr; + INFO(NCCL_NVLS, "NVLS Unmap mem UC handle 0x%llx(%p) MC handle 0x%llx(%p)", + resources->ucHandle, resources->ucBuff, resources->mcHandle, resources->mcBuff); + + // Release the UC memory and mapping + ptr = (CUdeviceptr)resources->ucBuff; + size = resources->size; + CUCHECK(cuMemUnmap(ptr, size)); + CUCHECK(cuMemAddressFree(ptr, size)); + CUCHECK(cuMemRelease(resources->ucHandle)); + + // Release the MC memory and mapping + ptr = (CUdeviceptr)resources->mcBuff; + size = resources->size; + CUCHECK(cuMemUnmap(ptr, size)); + CUCHECK(cuMemAddressFree(ptr, size)); + CUCHECK(cuMemRelease(resources->mcHandle)); + + return ncclSuccess; +} + +#include "bootstrap.h" +#include "channel.h" + +#define NVLS_MEM_ALIGN_SIZE (1 << 21) + +NCCL_PARAM(NvlsChannels, "NVLS_NCHANNELS", 16); + +NCCL_PARAM(NvlsEnable, "NVLS_ENABLE", 1); + +ncclResult_t ncclNvlsSetup(struct ncclComm* comm) { + if (!ncclParamNvlsEnable() || comm->localRanks <= 1 || comm->nNodes>1) return ncclSuccess; + CUdevice dev; + int driverVersion; + if (CUPFN(cuDeviceGet) == NULL) return ncclSuccess; + CUCHECK(cuDeviceGet(&dev, comm->cudaDev)); + CUDACHECK(cudaDriverGetVersion(&driverVersion)); + comm->nvlsSupport = 0; + // NVLS Multicast support requires CUDA12.1 UMD + KMD + if (CUPFN(cuMulticastCreate) != NULL && driverVersion >= 12010) { + CUCHECK(cuDeviceGetAttribute(&comm->nvlsSupport, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev)); + } + INFO(NCCL_INIT, "NVLS multicast support is %savailable on dev %d", comm->nvlsSupport ? "" : "not ", dev); + if (comm->nvlsSupport == 0) return ncclSuccess; + + int nChannels = comm->nvlsChannels = std::max(comm->minCTAs, std::min(comm->maxCTAs, (int)ncclParamNvlsChannels())); + int rank = comm->localRank, nranks = comm->localRanks; + + for (int c=0; cnvlsResources = resources; + + size_t buffSize = comm->buffSizes[NCCL_PROTO_SIMPLE]; + size_t memSize = NVLS_MEM_ALIGN_SIZE; + size_t nvlsPerRankSize = nChannels*2*(buffSize+memSize); + size_t nvlsTotalSize = nvlsPerRankSize*nranks; + + INFO(NCCL_INIT|NCCL_NVLS, "NVLS comm %p rank %d nranks %d buffSize %zi memSize %zi nvlsPerRankSize %zi nvlsTotalSize %zi", + comm, rank, nranks, buffSize, memSize, nvlsPerRankSize, nvlsTotalSize); + + char* nvlsShareableHandle = NULL; + NCCLCHECKGOTO(ncclCalloc(&nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup); + NCCLCHECKGOTO(nvlsGetProperties(comm, resources, dev, nranks, nvlsTotalSize), res, cleanup); + if (rank == 0) { + NCCLCHECKGOTO(nvlsGroupCreate(comm, resources, rank, nranks, nvlsShareableHandle), res, cleanup); + NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, rank, nranks, 0, nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup); + } else { + NCCLCHECKGOTO(bootstrapIntraNodeBroadcast(comm->bootstrap, comm->localRankToRank, rank, nranks, 0, nvlsShareableHandle, NVLS_HANDLE_SIZE), res, cleanup); + NCCLCHECKGOTO(nvlsGroupConnect(comm, resources, 0, nvlsShareableHandle), res, cleanup); + } + + NCCLCHECKGOTO(nvlsGroupAddDevice(comm, resources), res, cleanup); + NCCLCHECKGOTO(nvlsGroupBindMem(comm, resources), res, cleanup); + // Local intra-node barrier to ensure everyone has bound their memory to the group + NCCLCHECKGOTO(bootstrapBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]), res, cleanup); + NCCLCHECKGOTO(nvlsGroupMapMem(comm, resources), res, cleanup); + + for (int c=0; cchannels+c; + channel->nvls.nHeads = nranks; + for (int i=0; invls.up[i] = -1; + channel->nvls.down = comm->nRanks+1+comm->localRank; + channel->nvls.out = -1; // Network not yet implemented. + channel->nvls.headRank = comm->localRank; // Network not yet implemented. + } + + for (int r=0; rnRanks+1+r; + for (int c=0; cchannels+c; + channel->nvls.up[r] = nvlsPeer; + + char* mem = NULL; + struct ncclChannelPeer* peer = channel->peers+nvlsPeer; + + // Reduce UC -> MC + mem = resources->ucBuff + (r*2*nChannels+c)*(buffSize+memSize); + peer->send[0].transportComm = &nvlsTransport.send; + peer->send[0].conn.buffs[NCCL_PROTO_SIMPLE] = mem; + peer->send[0].conn.head = (uint64_t*)(mem+buffSize); + peer->send[0].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); + mem = resources->mcBuff + (r*2*nChannels+c)*(buffSize+memSize); + peer->recv[1].transportComm = &nvlsTransport.recv; + peer->recv[1].conn.buffs[NCCL_PROTO_SIMPLE] = mem; + peer->recv[1].conn.head = (uint64_t*)(mem+buffSize); + peer->recv[1].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); + peer->recv[1].conn.flags |= NCCL_NVLS_MIN_POLL; + + // Broadcast MC -> UC + mem = resources->ucBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize); + peer->recv[0].transportComm = &nvlsTransport.recv; + peer->recv[0].conn.buffs[NCCL_PROTO_SIMPLE] = mem; + peer->recv[0].conn.head = (uint64_t*)(mem+buffSize); + peer->recv[0].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); + mem = resources->mcBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize); + peer->send[1].transportComm = &nvlsTransport.send; + peer->send[1].conn.buffs[NCCL_PROTO_SIMPLE] = mem; + peer->send[1].conn.head = (uint64_t*)(mem+buffSize); + peer->send[1].conn.tail = (uint64_t*)(mem+buffSize+memSize/2); + peer->send[1].conn.flags |= NCCL_NVLS_MIN_POLL; + + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].send[0], &peer->send[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].recv[0], &peer->recv[0].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].send[1], &peer->send[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); + CUDACHECKGOTO(cudaMemcpyAsync(&comm->channels[c].devPeers[nvlsPeer].recv[1], &peer->recv[1].conn, sizeof(struct ncclConnInfo), cudaMemcpyHostToDevice, comm->hostStream.cudaStream), res, cleanup); + + /*INFO(NCCL_INIT|NCCL_NVLS, "Peer %d Channel %d MC buff %p/%p UC Buff %p/%p", + nvlsPeer, c, + resources->mcBuff + (r*2*nChannels+c)*(buffSize+memSize), + resources->mcBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize), + resources->ucBuff + (r*2*nChannels+c)*(buffSize+memSize), + resources->ucBuff + ((r*2+1)*nChannels+c)*(buffSize+memSize));*/ + } + } + + free(nvlsShareableHandle); + return res; + +cleanup: + comm->nvlsSupport = 0; + free(nvlsShareableHandle); + return res; +} + +ncclResult_t ncclNvlsFree(struct ncclComm* comm) { + struct nvlsResources* resources = (struct nvlsResources*)comm->nvlsResources; + if (resources == NULL) return ncclSuccess; + NCCLCHECK(nvlsGroupUnbind(comm, resources)); + NCCLCHECK(nvlsGroupUnmapMem(comm, resources)); + free(resources); + comm->nvlsResources = NULL; + return ncclSuccess; +} + +#else + +/* + * Pre CUDA 12.1 stubs + */ + +ncclResult_t ncclNvlsSetup(struct ncclComm* comm) { + return ncclSuccess; +} + +ncclResult_t ncclNvlsFree(struct ncclComm* comm) { + return ncclSuccess; +} + +#endif /* CUDA_VERSION >= 12010 */ diff --git a/src/transport/p2p.cc b/src/transport/p2p.cc index 5b2e3b7571..d73d451e01 100644 --- a/src/transport/p2p.cc +++ b/src/transport/p2p.cc @@ -267,11 +267,11 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st if (intermediateRank == -1) { info->rank = myInfo->rank; if (myInfo->pidHash == peerInfo->pidHash && useMemcpy == 0) { - if (ncclParamP2pDirectDisable() == 0) send->conn.direct |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; + if (ncclParamP2pDirectDisable() == 0) send->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; INFO(NCCL_INIT|NCCL_P2P, "Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/direct pointer%s comm %p nRanks %02d", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, comm, comm->nRanks); } else { - send->conn.direct |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; + send->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; INFO(NCCL_INIT|NCCL_P2P,"Channel %02d/%01d : %d[%lx] -> %d[%lx] via P2P/IPC%s%s comm %p nRanks %02d", channelId, connIndex, myInfo->rank, myInfo->busId, peerInfo->rank, peerInfo->busId, useReadStr, useMemcpy ? "/CE" : "", comm, comm->nRanks); } @@ -284,11 +284,11 @@ ncclResult_t p2pSendSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 1, info->rank, &send->proxyConn)); if (useMemcpy) { - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgSetup, NULL, 0, &resources->proxyInfo, sizeof(struct p2pProxyInfo))); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, NULL, 0, &resources->proxyInfo, sizeof(struct p2pProxyInfo))); info->shmSize = resources->proxyInfo.shmSize; memcpy(info->shmName, resources->proxyInfo.shmName, sizeof(info->shmName)); } else { - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgSetup, &sendSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgSetup, &sendSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); NCCLCHECK(p2pMap(myInfo, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&resources->devMem, &resources->sendMemIpc)); } @@ -318,16 +318,16 @@ ncclResult_t p2pRecvSetup(struct ncclComm* comm, struct ncclTopoGraph* graph, st if (intermediateRank == -1) { info->rank = myInfo->rank; if (myInfo->pidHash == peerInfo->pidHash && useMemcpy == 0) { - if (ncclParamP2pDirectDisable() == 0) recv->conn.direct |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; + if (ncclParamP2pDirectDisable() == 0) recv->conn.flags |= info->read ? NCCL_DIRECT_READ : NCCL_DIRECT_WRITE; } else { - recv->conn.direct |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; + recv->conn.flags |= info->read ? NCCL_IPC_READ : NCCL_IPC_WRITE; } } else { info->rank = intermediateRank; } NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_P2P, 0, info->rank, &recv->proxyConn)); - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgSetup, &recvSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgSetup, &recvSize, sizeof(int), &info->p2pBuff, sizeof(struct ncclP2pBuff))); NCCLCHECK(p2pMap(myInfo, comm->peerInfo+info->rank, &info->p2pBuff, (void**)&resources->devMem, &resources->recvMemIpc)); return ncclSuccess; @@ -358,7 +358,7 @@ static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* co send->conn.sizesFifo = resources->proxyInfo.ceRecvMem->sizesFifo; send->conn.head = &resources->proxyInfo.devShm->sendMem.head; // Send SIMPLE buff to proxy, and replace it by local buffer - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgConnect, &send->conn.buffs[NCCL_PROTO_SIMPLE], sizeof(void*), NULL, 0)); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &send->conn.buffs[NCCL_PROTO_SIMPLE], sizeof(void*), NULL, 0)); send->conn.buffs[NCCL_PROTO_SIMPLE] = resources->proxyInfo.ceDevBuff; } else { send->conn.tail = &remDevMem->tail; diff --git a/src/transport/shm.cc b/src/transport/shm.cc index 689e01cfb4..e125df2c2f 100644 --- a/src/transport/shm.cc +++ b/src/transport/shm.cc @@ -157,7 +157,7 @@ static ncclResult_t shmSendConnect(struct ncclComm* comm, struct ncclConnect* co if (useMemcpySend) { NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_SHM, 1, comm->rank, &send->proxyConn)); struct shmProxyInfo proxyInfo = { NULL, NULL, send->conn.buffs[NCCL_PROTO_SIMPLE], resources->hostMem, resources->remHostMem }; - NCCLCHECK(ncclProxyCall(&send->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); + NCCLCHECK(ncclProxyCallBlocking(&send->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); send->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo; send->conn.tail = &proxyInfo.ceRecvMem->tail; send->conn.sizesFifo = proxyInfo.ceRecvMem->sizesFifo; @@ -187,7 +187,7 @@ static ncclResult_t shmRecvConnect(struct ncclComm* comm, struct ncclConnect* co if (useMemcpyRecv) { NCCLCHECK(ncclProxyConnect(comm, TRANSPORT_SHM, 0, comm->rank, &recv->proxyConn)); struct shmProxyInfo proxyInfo = { NULL, NULL, recv->conn.buffs[NCCL_PROTO_SIMPLE], resources->remHostMem, resources->hostMem }; - NCCLCHECK(ncclProxyCall(&recv->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); + NCCLCHECK(ncclProxyCallBlocking(&recv->proxyConn, ncclProxyMsgConnect, &proxyInfo, sizeof(struct shmProxyInfo), &proxyInfo, sizeof(struct shmProxyInfo))); recv->conn.buffs[NCCL_PROTO_SIMPLE] = proxyInfo.devFifo; recv->conn.tail = &proxyInfo.ceRecvMem->tail; }