From d009ab144eeb15800d7344e800cbdc9f8683ec88 Mon Sep 17 00:00:00 2001 From: Mustafa Abduljabbar Date: Thu, 11 Dec 2025 19:04:35 -0500 Subject: [PATCH] [Device] WarpSpeed enablement and single node CU and perf opt for MI350 (#2073) --- CHANGELOG.md | 3 + CMakeLists.txt | 18 +++++ install.sh | 9 ++- src/device/all_gather.h | 13 +++- src/device/all_reduce.h | 10 +++ src/device/broadcast.h | 9 +++ src/device/common.cu | 24 +++---- src/device/common.h | 77 +++++++++++++++++--- src/device/msccl_kernel_impl.h | 3 + src/device/prims_ll.h | 4 ++ src/device/prims_ll128.h | 4 ++ src/device/prims_simple.h | 27 +++++-- src/device/reduce.h | 9 +++ src/device/reduce_scatter.h | 9 +++ src/enqueue.cc | 47 +++++++++---- src/graph/connect.cc | 22 ++++-- src/graph/topo.h | 3 + src/include/comm.h | 9 ++- src/include/device.h | 32 +++++++-- src/include/rccl_common.h | 5 ++ src/init.cc | 28 +++++--- src/rccl_wrap.cc | 124 ++++++++++++++++++++++++++++++++- 22 files changed, 424 insertions(+), 65 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f9117b244..aa3c135b78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Full documentation for RCCL is available at [https://rccl.readthedocs.io](https: * RCCL error messages have been made more verbose in several cases. RCCL now prints out fatal error messages by default. Fatal error messages can be suppressed by setting `NCCL_DEBUG=NONE`. * Disabled `reduceCopyPacks` pipelining for `gfx950`. +* Experimental support for traffic shaping using warp specialization (also known as WarpSpeed) is now available for the Ring algorithm. +* Enabling WarpSpeed in auto mode using RCCL_WARP_SPEED_AUTO optimizes performance and reduces the CU count by 50% on a single node for AllReduce, AllGather from 64MB, and ReduceScatter from 256MB. +* The following configuration knobs control WarpSpeed behavior for debugging purposes: `RCCL_WARP_SPEED_ENABLE`, `RCCL_UNROLL_FACTOR`, `RCCL_WARP_SPEED_CU_COUNT`, and `RCCL_THREADS_PER_BLOCK`. 
Note that the effective unroll factor is calculated as 2 raised to the value of `RCCL_UNROLL_FACTOR`. ## Unreleased - RCCL 2.27.7 for ROCm 7.1.1 diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b5cde31d3..8379b93bcb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -333,6 +333,7 @@ endif() ## Currently MSCCL++ is supported only on gfx942 and gfx950, and only on Ubuntu and CentOS set(MSCCLPP_SUPPORTED_ARCHS "gfx942" "gfx942:xnack-" "gfx942:xnack+" "gfx950" "gfx950:xnack-" "gfx950:xnack+") + # Check if any of the supported architectures are in GPU_TARGETS set(ARCH_MATCH_FOUND OFF) set(MSCCLPP_GPU_TARGETS "") @@ -355,6 +356,20 @@ if (ENABLE_MSCCLPP AND ROCM_VERSION VERSION_LESS "60200") message(WARNING "MSCCL++ integration only supported on ROCm 6.2.0 or greater; disabling MSCCL++ build") endif() +## Disable WARP_SPEED if the build environment is invalid +set(WARP_SPEED_SUPPORTED_ARCHS "gfx942" "gfx942:xnack-" "gfx942:xnack+" "gfx950" "gfx950:xnack-" "gfx950:xnack+") +set(ARCH_MATCH_FOUND OFF) +foreach(ARCH IN LISTS GPU_TARGETS) + if(ARCH IN_LIST WARP_SPEED_SUPPORTED_ARCHS) + set(ARCH_MATCH_FOUND ON) + endif() +endforeach() +if (NOT ARCH_MATCH_FOUND) + set(ENABLE_WARP_SPEED OFF) + message(WARNING "Can only build WARP_SPEED for supported GPU_TARGETS: ${WARP_SPEED_SUPPORTED_ARCHS}; current GPU_TARGETS: ${GPU_TARGETS}; so disabling WARP_SPEED build") +endif() + + # cmake_host_system_information(RESULT HOST_OS_ID QUERY DISTRIB_ID) ## Requires cmake 3.22 execute_process( COMMAND bash -c "grep '^ID=' /etc/os-release | cut -d'=' -f2 | cut -d'\"' -f2" @@ -875,6 +890,9 @@ endif() if(HAVE_ROCM_SMI_THREAD_ONLY_MUTEX) target_compile_definitions(rccl PRIVATE USE_ROCM_SMI_THREAD_ONLY_MUTEX) endif() +if(ENABLE_WARP_SPEED) + target_compile_definitions(rccl PRIVATE ENABLE_WARP_SPEED) +endif() # NPKit flags ## May be better to move these to a separate file diff --git a/install.sh b/install.sh index ac24aee79f..db62ad8dcc 100755 --- a/install.sh +++ b/install.sh @@ -39,6 +39,7 @@ 
run_tests_all=false time_trace=false force_reduce_pipeline=false generate_sym_kernels=false +warp_speed_enabled=true # note that this flag will be overridden to false for non MI350/MI300 platforms quiet_warnings=false # ################################################# @@ -90,7 +91,7 @@ function display_help() # check if we have a modern version of getopt that can handle whitespace and long parameters getopt -T if [[ "$?" -eq 4 ]]; then - GETOPT_PARSE=$(getopt --name "${0}" --options cdfhij:lprtq --longoptions address-sanitizer,dependencies,debug,dump-asm,enable-code-coverage,enable_backtrace,disable-colltrace,disable-msccl-kernel,enable-mscclpp,fast,help,install,jobs:,kernel-resource-use,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,log-trace,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,force-reduce-pipeline,generate-sym-kernels,quiet-warnings,verbose -- "$@") + GETOPT_PARSE=$(getopt --name "${0}" --options cdfhij:lprtq --longoptions address-sanitizer,dependencies,debug,dump-asm,enable-code-coverage,enable_backtrace,disable-colltrace,disable-msccl-kernel,enable-mscclpp,fast,help,install,jobs:,kernel-resource-use,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,log-trace,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,force-reduce-pipeline,generate-sym-kernels,quiet-warnings,disable-warp-speed,verbose -- "$@") else echo "Need a new version of getopt" exit 1 @@ -137,6 +138,7 @@ while true; do --verbose) build_verbose=true; shift ;; --force-reduce-pipeline) force_reduce_pipeline=true; shift ;; --generate-sym-kernels) generate_sym_kernels=true; shift ;; + --disable-warp-speed) warp_speed_enabled=false; shift ;; -q | --quiet-warnings) quiet_warnings=true; shift ;; --) shift ; break ;; *) echo "Unexpected command line parameter received; aborting"; @@ -316,6 +318,11 @@ if [[ 
"${npkit_enabled}" == true ]]; then cmake_common_options="${cmake_common_options} -DENABLE_NPKIT=ON" fi +# Enable WARP_SPEED only on MI350/MI300 platforms +if [[ "${warp_speed_enabled}" == true ]]; then + cmake_common_options="${cmake_common_options} -DENABLE_WARP_SPEED=ON" +fi + # Suppress Warnings if [[ "${quiet_warnings}" == true ]]; then cmake_common_options="${cmake_common_options} -DQUIET_WARNINGS=ON" diff --git a/src/device/all_gather.h b/src/device/all_gather.h index 2dc1f0417e..3408da2549 100644 --- a/src/device/all_gather.h +++ b/src/device/all_gather.h @@ -20,11 +20,20 @@ namespace { const int bid = ncclShmem.channelId - work->channelLo; int npKitCtxIdx = bid; // unused variable - compiler warning #endif +#ifdef ENABLE_WARP_SPEED + int warp = threadIdx.x / WARP_SIZE; + ncclRing *ring = &ncclShmem.warpChannel[warp].ring; +#else ncclRing *ring = &ncclShmem.channel.ring; +#endif const int *ringRanks = ring->userRanks; const int nranks = ncclShmem.comm.nRanks; ssize_t count, partOffset, partCount, chunkCount; +#ifdef ENABLE_WARP_SPEED + ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &count, &partOffset, &partCount, &chunkCount); +#else ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &count, &partOffset, &partCount, &chunkCount); +#endif ssize_t offset; ssize_t dataOffset; int nelem; @@ -142,7 +151,7 @@ namespace { #endif // Final wait/copy. 
prims.directRecv(offset, nelem); - + #if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT) if (tid == 0) { NpKit::CollectGpuEvent(NPKIT_EVENT_ALL_GATHER_RING_DIRECT_RECV_EXIT, nelem*sizeof(T), prims.npKitDataProcessTotalTime, NPKIT_GET_GPU_TIMESTAMP(), @@ -671,4 +680,4 @@ struct RunWorkCollindex; + const int nranks = ncclShmem.comm.nRanks; #if defined(ENABLE_NPKIT) const int bid = ncclShmem.channelId - work->channelLo; @@ -31,7 +37,11 @@ namespace { ssize_t gridOffset; ssize_t channelCount; ssize_t chunkCount; +#ifdef ENABLE_WARP_SPEED + ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount); +#else ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount); +#endif const ssize_t loopCount = nranks * chunkCount; ssize_t offset; int nelem; diff --git a/src/device/broadcast.h b/src/device/broadcast.h index 364e87ee2b..55352c9be3 100644 --- a/src/device/broadcast.h +++ b/src/device/broadcast.h @@ -19,7 +19,12 @@ namespace { const int bid = ncclShmem.channelId - work->channelLo; int npKitCtxIdx = bid; // unused variable - compiler warning #endif +#ifdef ENABLE_WARP_SPEED + int warp = threadIdx.x / WARP_SIZE; + ncclRing *ring = &ncclShmem.warpChannel[warp].ring; +#else ncclRing *ring = &ncclShmem.channel.ring; +#endif const int rank = ring->userRanks[0]; const int nextRank = ring->userRanks[1]; const int root = work->root; @@ -27,7 +32,11 @@ namespace { ssize_t chunkCount; ssize_t channelCount; ssize_t gridOffset; +#ifdef ENABLE_WARP_SPEED + ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount); +#else ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &size, &gridOffset, &channelCount, &chunkCount); +#endif size_t offset; int nelem; int workNthreads; diff --git a/src/device/common.cu b/src/device/common.cu index 
36d396fbb8..70a7dc64ef 100644 --- a/src/device/common.cu +++ b/src/device/common.cu @@ -17,24 +17,24 @@ struct RunWorkNop { __device__ void run() {} }; -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&args4K.args); +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_1(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/1>(&argsStorage.args); } -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&args4K.args); +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_2(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/2>(&argsStorage.args); } -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&args4K.args); +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernel_Generic_4(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/false, /*Unroll*/4>(&argsStorage.args); } #ifdef ENABLE_COLLTRACE -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/1>(&args4K.args); +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, 
/*Unroll*/1>(&argsStorage.args); } -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&args4K.args); +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/2>(&argsStorage.args); } -__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) { - ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&args4K.args); +__launch_bounds__(NCCL_MAX_NTHREADS, 1) __global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) { + ncclKernelMain<-1, RunWorkNop, /*COLLTRACE*/true, /*Unroll*/4>(&argsStorage.args); } #endif diff --git a/src/device/common.h b/src/device/common.h index ada2eba98f..e526d85c14 100644 --- a/src/device/common.h +++ b/src/device/common.h @@ -138,7 +138,11 @@ struct ncclShmemData { int aborted; alignas(16) struct ncclDevComm comm; alignas(16) struct ncclDevChannel channel; - +#ifdef ENABLE_WARP_SPEED + int warpComm; + alignas(16) struct ncclDevChannel warpChannel[NCCL_MAX_GROUPS]; + int warpChannelId[NCCL_MAX_GROUPS]; +#endif int batchIx, nextBatchIx; enum ncclDevWorkType workType; uint8_t directMode; @@ -442,10 +446,17 @@ struct RunWorkBatch { if (work->nWarps != workPrev->nWarps) __syncthreads(); } int subtn = work->nWarps*WARP_SIZE; +#ifdef ENABLE_WARP_SPEED + if (tid < subtn) { + if(ncclShmem.warpComm == 0 || Algo != NCCL_ALGO_RING) RunWorkColl().run(tid, subtn, work); + else if (ncclShmem.warpChannelId[tid / WARP_SIZE] >= 0) RunWorkColl().run(tid % WARP_SIZE, WARP_SIZE, work); + } +#else // Coverity reports a possible thread divergence due to not all threads participating in the collective. 
// However, the code ensures that the participation is on a per-warp basis. // coverity[device_thread_diverged:FALSE] if (tid < subtn) RunWorkColl().run(tid, subtn, work); +#endif } } }; @@ -489,7 +500,12 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a int x = tid; int total = 0, y; int num = MAXCHANNELS/64 > 0 ? MAXCHANNELS/64 : 1; - +#ifdef ENABLE_WARP_SPEED + int warpCount = tn / WARP_SIZE; + int localWarpId = tid / WARP_SIZE; + int globalWarpId = (warpCount * blockIdx.x) + localWarpId; + int laneId = tid % WARP_SIZE; +#endif // Copy kernel args to shmem and then only read those. Otherwise the compiler // will end up putting the args into thread local stack which is very wasteful. if (tid < sizeof(ncclDevKernelArgs)/sizeof(uint32_t)) { @@ -583,9 +599,52 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a ncclShmem.collTrace = args->comm->collTrace + COLLTRACE_NUM_ITEMS*ncclShmem.channelId; ncclShmem.collTraceTail = args->comm->collTraceTail + ncclShmem.channelId; } +#endif +#ifdef ENABLE_WARP_SPEED + if(tid == 0) { + ncclShmem.warpComm = args->comm->warpLevelComm; + } #endif __syncthreads(); // publish shmem +#ifdef ENABLE_WARP_SPEED + // Determine per-warp channel assignment for WarpSpeed enablement + total = 0; + if(ncclShmem.warpComm == 1) { // If warpComm is enabled, assign warps to channels that have the corresponding channel mask enabled + ncclShmem.warpChannelId[localWarpId] = -1; + __syncthreads(); + for (int i = 0; i < num; i++) { + if (args->channelMask.masks[i] & (1ull<channelMask.masks[i] & ((1ull<channelMask.masks[i]); + } + __syncthreads(); + if(ncclShmem.warpChannelId[localWarpId] >= 0) { + void* dst = &ncclShmem.warpChannel[localWarpId]; + void* src = &((ncclDevCommAndChannels*)ncclShmem.args.comm)->channels[ncclShmem.warpChannelId[localWarpId]]; + int bytes = sizeof(ncclDevChannel); + static_assert(sizeof(ncclDevChannel) <= 16*WARP_SIZE, "ncclDevChannel cannot be loaded by a 
single warp in one insn."); + // assert((tid-localWarpId*WARP_SIZE) >= 0 && (tid-localWarpId*WARP_SIZE) < WARP_SIZE); + copyToShmem16(tid-localWarpId*WARP_SIZE, dst, src, bytes); + } + } else { // If warpComm is disabled, all warps use the same channel as the block + if(laneId == 0) { + ncclShmem.warpChannelId[localWarpId] = ncclShmem.channelId; + } + // Use all threads in the warp to copy the channel data in parallel + void* dst = &ncclShmem.warpChannel[localWarpId]; + void* src = &ncclShmem.channel; + int bytes = sizeof(ncclDevChannel); + copyToShmem16(laneId, dst, src, bytes); + } + __syncthreads(); +#endif #ifdef ENABLE_PROFILING if (tid == 0) { ncclShmem.prof.count = 0; @@ -648,17 +707,17 @@ __device__ __forceinline__ void ncclKernelMain(struct ncclDevKernelArgs const* a #endif } -__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); +__global__ void ncclDevKernel_Generic_1(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage); +__global__ void ncclDevKernel_Generic_2(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage); +__global__ void ncclDevKernel_Generic_4(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage); #ifdef ENABLE_COLLTRACE -__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); -__global__ void ncclDevKernelDebug_Generic_4(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K); +__global__ void ncclDevKernelDebug_Generic_1(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage); +__global__ void ncclDevKernelDebug_Generic_2(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage); +__global__ void 
ncclDevKernelDebug_Generic_4(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage); #endif #define DEFINE_ncclDevKernel_nop(suffix, coll, redop, ty, algo, proto, specializedFnId) \ - __global__ void ncclDevKernel_##suffix(ncclDevKernelArgs4K NCCL_GRID_CONSTANT const args4K) {} + __global__ void ncclDevKernel_##suffix(ncclDevKernelArgsDefaultStorage NCCL_GRID_CONSTANT const argsStorage) {} #ifdef USE_INDIRECT_FUNCTION_CALL #define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto, acc, pipeline, unroll) \ diff --git a/src/device/msccl_kernel_impl.h b/src/device/msccl_kernel_impl.h index 87ba510ba4..3d2a2eda26 100644 --- a/src/device/msccl_kernel_impl.h +++ b/src/device/msccl_kernel_impl.h @@ -146,6 +146,9 @@ __device__ __forceinline__ void mscclRunInterpreter( } if (bytes) copyToShmem8(tid%WARP_SIZE, dst, src, bytes); } +#ifdef ENABLE_WARP_SPEED + ncclShmem.warpComm = 0; +#endif __syncthreads(); // publish shmem #if defined(ENABLE_NPKIT) diff --git a/src/device/prims_ll.h b/src/device/prims_ll.h index 1b21da8b2d..b7e482b175 100644 --- a/src/device/prims_ll.h +++ b/src/device/prims_ll.h @@ -654,7 +654,11 @@ public: redOp(redOpArg), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), group(group), threadsPerBlock(blockDim.x), stepLines(ncclShmem.comm.buffSizes[NCCL_PROTO_LL]/NCCL_STEPS/sizeof(ncclLLFifoLine)) { +#ifdef ENABLE_WARP_SPEED + auto *channel = isMsccl(Metadata) ? 
&ncclShmem.channel : &ncclShmem.warpChannel[threadIdx.x / WARP_SIZE]; +#else auto *channel = &ncclShmem.channel; +#endif barriers = &ncclShmem.groups[group].barrier; // If we are going to support oneshot collNet + LL, then we would need to add connector index here int nrecv=0, nsend=0; diff --git a/src/device/prims_ll128.h b/src/device/prims_ll128.h index 42024b5548..283868fc06 100644 --- a/src/device/prims_ll128.h +++ b/src/device/prims_ll128.h @@ -579,7 +579,11 @@ public: tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), /*compiler warnings*/ stepSize(ncclShmem.comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)), warp(tid/WARP_SIZE), warpInBlock(threadIdx.x/WARP_SIZE), flagThread((tid%4)==3), group(group), threadsPerBlock(blockDim.x){ +#ifdef ENABLE_WARP_SPEED + auto *channel = isMsccl(Metadata) ? &ncclShmem.channel : &ncclShmem.warpChannel[warpInBlock]; +#else auto *channel = &ncclShmem.channel; +#endif barriers = &ncclShmem.groups[group].barrier; int nrecv=0, nsend=0; while (nrecv < MaxRecv && recvPeers[nrecv] >= 0) { diff --git a/src/device/prims_simple.h b/src/device/prims_simple.h index 6a7260f7a2..b5de1def50 100644 --- a/src/device/prims_simple.h +++ b/src/device/prims_simple.h @@ -502,14 +502,22 @@ private: public: static inline __device__ void sendPeerNotify(int peer, int connIndex, int steps) { +#ifdef ENABLE_WARP_SPEED + ncclDevChannelPeer* peerPtr = ncclShmem.warpChannel[threadIdx.x/WARP_SIZE].peers[peer]; +#else ncclDevChannelPeer* peerPtr = ncclShmem.channel.peers[peer]; +#endif peerPtr->send[connIndex].step += steps; st_relaxed_sys_global(peerPtr->send[connIndex].tail, peerPtr->send[connIndex].step); } static inline __device__ void recvPeerNotify(int peer, int connIndex, int steps) { int spins = 0; +#ifdef ENABLE_WARP_SPEED + ncclDevChannelPeer* peerPtr = ncclShmem.warpChannel[threadIdx.x/WARP_SIZE].peers[peer]; +#else ncclDevChannelPeer* peerPtr = ncclShmem.channel.peers[peer]; +#endif peerPtr->recv[connIndex].step += steps; 
st_relaxed_sys_global(peerPtr->recv[connIndex].head, peerPtr->recv[connIndex].step); while (ld_volatile_global(peerPtr->recv[connIndex].tail) < peerPtr->recv[connIndex].step) { @@ -770,13 +778,20 @@ public: struct ncclDevWorkP2p* p2pWork = nullptr, int stepSize_ = 0, int mode = primsModeDefault ): tid(tid), tidInBlock(threadIdx.x), nthreads(nthreads), /*compiler warnings*/ +#ifdef ENABLE_WARP_SPEED + stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_), group(ncclShmem.warpComm? tidInBlock / WARP_SIZE : group), threadsPerBlock(blockDim.x){ +#else stepSize(stepSize_ == 0 ? ncclShmem.comm.buffSizes[NCCL_PROTO_SIMPLE]/NCCL_STEPS/sizeof(T) : stepSize_), group(group), threadsPerBlock(blockDim.x){ - +#endif barriers = &ncclShmem.groups[group].barrier; // PAT uses the same barrier for each group barriers_pat = &ncclShmem.barrier_pat; this->nworkers = nthreads; - +#ifdef ENABLE_WARP_SPEED + auto *channel = isMsccl(Metadata) ? &ncclShmem.channel : &ncclShmem.warpChannel[tidInBlock/WARP_SIZE]; +#else + auto *channel = &ncclShmem.channel; +#endif int peer = -1; flags = 0; index = -1; @@ -831,9 +846,9 @@ public: } // coverity[overrun-call] => Coverity think prims.index can be greater than 1 - if (flags & (RoleWaitRecv|RolePostRecv)) loadRecvConn(ncclShmem.channel.peers[peer], connIndexRecv, collWork ? collWork->direct : 0, recvIpcReg, recvNetReg); + if (flags & (RoleWaitRecv|RolePostRecv)) loadRecvConn(channel->peers[peer], connIndexRecv, collWork ? collWork->direct : 0, recvIpcReg, recvNetReg); // coverity[overrun-call] => Coverity think prims.index can be greater than 1 - if (flags & (RoleWaitSend|RolePostSend)) loadSendConn(ncclShmem.channel.peers[peer], connIndexSend, collWork ? collWork->direct : 0, sendIpcReg, sendNetReg); + if (flags & (RoleWaitSend|RolePostSend)) loadSendConn(channel->peers[peer], connIndexSend, collWork ? 
collWork->direct : 0, sendIpcReg, sendNetReg); // if (barrierAny(flags & NetDeviceUnpack)) { // flags |= AnyNetDeviceUnpack; @@ -861,7 +876,7 @@ public: // Load recv peer int recvPeer = mode == primsModePatRs ? (rank - delta + nranks) % nranks : (rank + delta) % nranks; struct ncclPatPeer* peer = ((struct ncclPatPeer*)recvPeers)+tid; - struct ncclConnInfo* conn = peer->conn = ncclShmem.channel.peers[recvPeer]->recv+connIndexRecv; + struct ncclConnInfo* conn = peer->conn = channel->peers[recvPeer]->recv+connIndexRecv; peer->step = conn->step; peer->buff = conn->buffs[NCCL_PROTO_SIMPLE]; peer->stepCache = loadStepValue(peer->tailPtr = conn->tail); @@ -871,7 +886,7 @@ public: // Load send peer int sendPeer = mode == primsModePatAg ? (rank - delta + nranks) % nranks : (rank + delta) % nranks; peer = ((struct ncclPatPeer*)sendPeers)+tid; - conn = peer->conn = ncclShmem.channel.peers[sendPeer]->send+connIndexSend; + conn = peer->conn = channel->peers[sendPeer]->send+connIndexSend; peer->step = conn->step; peer->connFifo = conn->connFifo; peer->buff = conn->buffs[NCCL_PROTO_SIMPLE]; diff --git a/src/device/reduce.h b/src/device/reduce.h index 4ca3fb28cb..efc39df93e 100644 --- a/src/device/reduce.h +++ b/src/device/reduce.h @@ -16,7 +16,12 @@ namespace { #else __device__ __attribute__((noinline)) void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) { #endif +#ifdef ENABLE_WARP_SPEED + int warp = threadIdx.x / WARP_SIZE; + ncclRing *ring = &ncclShmem.warpChannel[warp].ring; +#else ncclRing *ring = &ncclShmem.channel.ring; +#endif const int nranks = ncclShmem.comm.nRanks; const int rank = ncclShmem.comm.rank; const int prevRank = ring->userRanks[nranks-1]; @@ -24,7 +29,11 @@ namespace { size_t chunkCount; size_t channelCount; size_t gridOffset; +#ifdef ENABLE_WARP_SPEED + ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), (size_t*)nullptr, &gridOffset, &channelCount, &chunkCount); +#else ncclCollCbdPart(work, ncclShmem.channelId, 
Proto::Id, sizeof(T), (size_t*)nullptr, &gridOffset, &channelCount, &chunkCount); +#endif size_t offset; int nelem; diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h index e5c6896143..8ce151b66f 100644 --- a/src/device/reduce_scatter.h +++ b/src/device/reduce_scatter.h @@ -16,14 +16,23 @@ namespace { #else __device__ __attribute__((noinline)) void runRing(int tid, int nthreads, struct ncclDevWorkColl* work) { #endif +#ifdef ENABLE_WARP_SPEED + int warp = threadIdx.x / WARP_SIZE; + ncclRing *ring = &ncclShmem.warpChannel[warp].ring; +#else ncclRing *ring = &ncclShmem.channel.ring; +#endif int const *ringRanks = ring->userRanks; const int nranks = ncclShmem.comm.nRanks; size_t count; size_t gridOffset; size_t channelCount; size_t chunkCount; +#ifdef ENABLE_WARP_SPEED + ncclCollCbdPart(work, ncclShmem.warpChannelId[warp], Proto::Id, sizeof(T), &count, &gridOffset, &channelCount, &chunkCount); +#else ncclCollCbdPart(work, ncclShmem.channelId, Proto::Id, sizeof(T), &count, &gridOffset, &channelCount, &chunkCount); +#endif size_t offset; size_t dataOffset; uint32_t nelem; diff --git a/src/enqueue.cc b/src/enqueue.cc index 6f1d5c2277..f8733818c3 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -390,7 +390,7 @@ ncclResult_t ncclTasksRegAndEnqueue(struct ncclComm* comm) { devWork.rcclUseOneSlice = comm->rcclUseOneSlice; //[Added-comment] opCount is missing for collDevWork, adding here devWork.opCount = task->opCount; - + devWork.isOneRPN = comm->isOneRPN; devWork.netRegUsed = devWork.regUsed = 0; devWork.gfx9CheapFenceOff = gfx9CheapFenceOff(devWork, comm->gfx9CheapFenceOff); @@ -513,11 +513,11 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool WARN("%s: unsupported collective. 
Please ensure the collective has been enabled in build.", __func__); return ncclInvalidUsage; } - + if (!rcclIsArchSupportedForFunc(&agg, comm->archName)) { - WARN("%s: unsupported architecture (%s) for collective %s(%s, %s, %s, %s, Acc=%d, Pipeline=%d).", - __func__, comm->archName, - ncclFuncToString(task->func), ncclAlgoToString(task->algorithm), ncclProtoToString(task->protocol), + WARN("%s: unsupported architecture (%s) for collective %s(%s, %s, %s, %s, Acc=%d, Pipeline=%d).", + __func__, comm->archName, + ncclFuncToString(task->func), ncclAlgoToString(task->algorithm), ncclProtoToString(task->protocol), ncclDevRedOpToString(task->opDev.op), ncclDatatypeToString(task->datatype), (agg.acc != nullptr), agg.pipeline); return ncclInvalidUsage; } @@ -541,6 +541,9 @@ ncclResult_t ncclPrepareTasks(struct ncclComm* comm, bool* algoNeedConnect, bool aggBeg->protocol = agg.protocol; aggBeg->acc = agg.acc; aggBeg->pipeline = agg.pipeline; +#ifdef ENABLE_WARP_SPEED + aggBeg->useWarpSpeed = agg.useWarpSpeed; +#endif if (aggBeg->protocol == NCCL_PROTO_LL) aggBeg->trafficBytes *= 4; aggBeg->nMaxChannels = agg.nMaxChannels; aggBeg->nWarps = agg.nWarps; @@ -764,10 +767,12 @@ static ncclResult_t scheduleCollTasksToPlan( (countHi != 0 ? countHi : countLo) -= cells*elementsPerCell - task->count; nChannels = (countLo!=0 ? 1 : 0) + nMidChannels + (cellsHi!=0 ? 
1 : 0); - // Update number of channels propagated to the profiler - task->nChannels = (uint8_t)nChannels; - +#ifdef ENABLE_WARP_SPEED + task->nChannels = nChannels; +#else + task->nChannels = (uint8_t) nChannels; +#endif // Ensure room for worst case of one new batch per channel if (!testBudget(budget, plan->nWorkBatches + nChannels, plan->workBytes + workNode->size)) { return ncclSuccess; @@ -1756,7 +1761,6 @@ NCCL_PARAM(MemSyncDomain, "MEM_SYNC_DOMAIN", cudaLaunchMemSyncDomainRemote); #endif NCCL_PARAM(NvlinkUtilCentricSchedEnable, "NVLINK_UTIL_CENTRIC_SCHED_ENABLE", 0); - ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan) { ncclResult_t ret = ncclSuccess; struct ncclKernelPlanner* planner = &comm->planner; @@ -1764,6 +1768,9 @@ ncclResult_t ncclLaunchKernel(struct ncclComm* comm, struct ncclKernelPlan* plan for (int i = 0; i < MAXCHANNELS/64; i++) nChannels += countOneBits(plan->channelMask.masks[i]); void* sym = plan->kernelFn; +#ifdef ENABLE_WARP_SPEED + rcclSetWarpSpeedSupportAndFinalCuCount(comm, plan, nChannels, plan->kernelArgs->comm->warpLevelComm, nChannels); +#endif dim3 grid = {(unsigned)nChannels, 1, 1}; dim3 block = {(unsigned)plan->threadPerBlock, 1, 1}; int smem = rcclShmemDynamicSize(comm->cudaArch, comm->WarpSize); @@ -1883,8 +1890,8 @@ ncclResult_t ncclLaunchKernelAfter_NoCuda(struct ncclComm* comm, struct ncclKern // hostStreamPlanTask directly NCCLCHECK(hostStreamPlanTask(comm, plan)); } - - // Increment the opCount for intranode comms as well. Previously if proxyOpQueue was empty + + // Increment the opCount for intranode comms as well. 
Previously if proxyOpQueue was empty // opCount was not incremented because ncclProxyStart wasn't called in hostStreamPlanTask if (!plan->persistent && ncclIntruQueueHead(&plan->proxyOpQueue) == nullptr) { comm->opCount++; @@ -2095,8 +2102,11 @@ static ncclResult_t topoGetAlgoInfo( rcclSetPipelining(comm, nBytes, info); if (simInfo) simInfo->estimatedTime = time; TRACE(NCCL_COLL, "%ld Bytes -> Algo %d proto %d time %f", nBytes, info->algorithm, info->protocol, time); - +#ifdef ENABLE_WARP_SPEED + int nc = comm->topo->warpSpeedEnabled? comm->nChannels / 2 : comm->nChannels; +#else int nc = comm->nChannels; +#endif int nt = comm->maxThreads[info->algorithm][info->protocol]; int threadThreshold = comm->threadThresholds[info->algorithm][info->protocol]; if (info->algorithm == NCCL_ALGO_COLLNET_DIRECT) { @@ -2167,7 +2177,13 @@ static ncclResult_t topoGetAlgoInfo( } } else if (info->func == ncclFuncAllReduce && comm->topo->treeDefined == 1) { info->algorithm = NCCL_ALGO_TREE; +#ifdef ENABLE_WARP_SPEED + nc = std::min(nc, 64); // Tree uses at most 64 channels as we don't support WarpSpeed Tree. + } else if (info->algorithm == NCCL_ALGO_TREE) { + nc = std::min(nc, 64); // Tree uses at most 64 channels as we don't support WarpSpeed Tree. 
+#else info->nMaxChannels = nc; +#endif } else { info->nMaxChannels = nc; } @@ -2180,6 +2196,13 @@ static ncclResult_t topoGetAlgoInfo( info->nWarps = nt/comm->WarpSize; rcclOverrideAlgorithm(ncclAlgoStr, table, info); rcclOverrideProtocol(ncclProtoStr, table, info); +#ifdef ENABLE_WARP_SPEED + rcclSetWarpSpeedAuto(comm, info, nBytes); + if(info->useWarpSpeed) { + rcclSetWarpSpeedCUs(comm, info->algorithm, info->nWarps * comm->WarpSize, nc); + } + info->nMaxChannels = nc; +#endif return ncclSuccess; } diff --git a/src/graph/connect.cc b/src/graph/connect.cc index e70b496cc6..ac03b68e3a 100644 --- a/src/graph/connect.cc +++ b/src/graph/connect.cc @@ -721,6 +721,7 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa int shared = parent && parent->nvlsSupport && parent->shareResources; int maxChannels; int minNchannels, maxNchannels; + int duplicateCount = 1; NCCLCHECK(ncclCalloc(&ringRecv, nNodes*MAXCHANNELS)); NCCLCHECKGOTO(ncclCalloc(&ringSend, nNodes*MAXCHANNELS), ret, fail); NCCLCHECKGOTO(ncclCalloc(&ringPrev, nranks*MAXCHANNELS), ret, fail); @@ -804,19 +805,30 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa } } +#ifdef ENABLE_WARP_SPEED + // Only use full MAXCHANNELS for gfx942 (MI300X) and gfx950 + maxChannels = (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") || + IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950")) + ? MAXCHANNELS : 2*CHANNEL_LIMIT; + +#else // Only use full MAXCHANNELS for gfx942 (MI300X) and gfx950 maxChannels = (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942") || IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950")) ? 
std::min(comm->topo->nodes[GPU].nodes[0].gpu.cu, MAXCHANNELS) : 2*CHANNEL_LIMIT; - if (graphs[NCCL_ALGO_RING]->nIntraChannels > 0 || comm->nNodes > 1) { maxChannels = std::min(64, maxChannels); } - +#endif // Duplicate ringPrev/ringNext for ncclBuildRing - if (nChannels <= maxChannels/2) memcpy(ringPrev+nChannels*nranks, ringPrev, nChannels*nranks*sizeof(int)); - if (nChannels <= maxChannels/2) memcpy(ringNext+nChannels*nranks, ringNext, nChannels*nranks*sizeof(int)); - + duplicateCount = maxChannels / nChannels; + if (duplicateCount > 1) { + int limit = duplicateCount; + for (int dup = 1; dup < limit; ++dup) { + memcpy(ringPrev + dup * nChannels * nranks, ringPrev, nChannels * nranks * sizeof(int)); + memcpy(ringNext + dup * nChannels * nranks, ringNext, nChannels * nranks * sizeof(int)); + } + } // Get number of channels after duplication maxNchannels = std::min((int)ncclMaxNchannels(), maxChannels); nc = std::min(maxNchannels/comm->nChannels, nc); diff --git a/src/graph/topo.h b/src/graph/topo.h index 3fc6681e06..986f8a09ad 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -207,6 +207,9 @@ struct ncclTopoSystem { int pivotA2ANumBiRings; bool treeDefined; bool ll128Enabled; +#ifdef ENABLE_WARP_SPEED + bool warpSpeedEnabled; +#endif float baseBw; bool mscclEnabled; diff --git a/src/include/comm.h b/src/include/comm.h index 65bb69b636..1f8d38b0ba 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -205,7 +205,12 @@ struct ncclTaskColl { int chunkSteps, sliceSteps; // Computed later: size_t trafficBytes; +#ifdef ENABLE_WARP_SPEED + int32_t nMaxChannels:16; + bool useWarpSpeed; +#else int32_t nMaxChannels:8; +#endif int32_t nWarps:8; int32_t algorithm:8, protocol:8, pipeline:8; uint32_t isCollnet:1, isNvls:1; @@ -550,7 +555,7 @@ struct ncclComm { float bandwidths[NCCL_NUM_FUNCTIONS][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; int maxThreads[NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS]; uint64_t minMaxLLRange[RCCL_TUNABLE_COLLS][NCCL_NUM_PROTOCOLS - 
1][RCCL_PROTOCOL_ENTRY_SIZE]; - uint64_t minMaxChannelThresholds[RCCL_TUNABLE_COLLS][RCCL_CHANNELS_TUNABLE_ENTRIES][3]; //for each collective, set for 5 channel-counts: 32,40,48,56,64, the two values for min/max size-threshold + uint64_t minMaxChannelThresholds[RCCL_TUNABLE_COLLS][RCCL_CHANNELS_TUNABLE_ENTRIES][3]; //for each collective, set for 5 channel-counts: 32,40,48,56,64, the two values for min/max size-threshold /* This attribute can indicate the states of communicators and return code of * asynchronous NCCL operations. */ @@ -719,7 +724,7 @@ struct ncclComm { char* archName; // multiProcessorCount from hipDeviceProp_t [RCCL] int cuCount; - + uint64_t endMagic; }; diff --git a/src/include/device.h b/src/include/device.h index 0cf756974c..91f9858dd8 100644 --- a/src/include/device.h +++ b/src/include/device.h @@ -126,8 +126,12 @@ union ncclLLFifoLine { #define NCCL_MAX_GROUPS (NCCL_MAX_NTHREADS/WARP_SIZE) #endif +#ifdef ENABLE_WARP_SPEED +#define MAXCHANNELS 512 +#else #define MAXCHANNELS 128 -#define CHANNEL_LIMIT 16 +#endif +#define CHANNEL_LIMIT 16 // this is used to limit channels for pre MI3xx GPUs #define NCCL_MAX_LOCAL_RANKS 72 #define NCCL_MIN_NTHREADS (4*WARP_SIZE) #define NCCL_SIMPLE_MAX_NTHREADS NCCL_MAX_NTHREADS @@ -354,7 +358,11 @@ inline __device__ int ncclP2pChannelToPart(int nP2pChannels, int base, int chann struct alignas(16) ncclDevWorkColl { // Running on channels [channelLo..channelHi], hi is inclusive. 
// nChannels == (channelHi - channelLo) + 1 +#ifdef ENABLE_WARP_SPEED + uint32_t channelLo:16, channelHi:16; +#else uint32_t channelLo:8, channelHi:8; +#endif uint32_t nWarps:8; uint32_t redOpArgIsPtr:1, regUsed:1, netRegUsed:1, oneNode:1, direct:2, isOneRPN:1, rcclUseOneSlice:1, gfx9CheapFenceOff:1; uint32_t root:30, connIndex:2; @@ -573,7 +581,7 @@ struct ncclDevComm { int p2pChunkSize; int isAllNvlink; int p2pnChannelsPerPeer; - + int warpLevelComm; int* collNetDenseToUserRank; // Flag to ask NCCL kernels to abort @@ -637,11 +645,6 @@ struct alignas(16) ncclDevKernelArgs { // struct ncclDevWorkBatch batches[]; }; -__host__ __device__ constexpr int ncclMaxKernelArgsSize(/*int cudaDriver, */int cudaArch=NCCL_CUDA_ARCH) { - //return (cudaArch < 700 || cudaDriver < 12010) ? 4<<10 : (32<<10)-4; - return 4<<10; -} - template struct alignas(16) ncclDevKernelArgsStorage { union { @@ -650,9 +653,24 @@ struct alignas(16) ncclDevKernelArgsStorage { }; }; + +typedef ncclDevKernelArgsStorage<(5<<10)> ncclDevKernelArgs5K; typedef ncclDevKernelArgsStorage<(4<<10)> ncclDevKernelArgs4K; //typedef ncclDevKernelArgsStorage<(32<<10)-4> ncclDevKernelArgs31K; +#ifdef ENABLE_WARP_SPEED +// needed extra storage for accommodating more channels than 128 for WarpSpeed support +// 256 channels (i.e. 256 warps) would hang without this extra storage +// 5KB should be sufficient for now +typedef ncclDevKernelArgs5K ncclDevKernelArgsDefaultStorage; +#else +typedef ncclDevKernelArgs4K ncclDevKernelArgsDefaultStorage; +#endif +__host__ __device__ constexpr int ncclMaxKernelArgsSize(/*int cudaDriver, */int cudaArch=NCCL_CUDA_ARCH) { + //return (cudaArch < 700 || cudaDriver < 12010) ? 
4<<10 : (32<<10)-4; + return sizeof(ncclDevKernelArgsDefaultStorage); +} + template __host__ __device__ constexpr T min_constexpr(T a) { return a; } template diff --git a/src/include/rccl_common.h b/src/include/rccl_common.h index 9ae65f7df8..dd7bd1b55f 100644 --- a/src/include/rccl_common.h +++ b/src/include/rccl_common.h @@ -118,4 +118,9 @@ ncclResult_t commSetUnrollFactor(struct ncclComm* comm); bool validHsaScratchEnvSetting(const char*hsaScratchEnv, int hipRuntimeVersion, int firmwareVersion, const char* archName); int parseFirmwareVersion(); bool rcclIsArchSupportedForFunc(struct ncclTaskColl* info, char const* archName); +#ifdef ENABLE_WARP_SPEED +void rcclSetWarpSpeedCUs(struct ncclComm* comm, int algo, int threadsPerBlock, int& rcclWarpSpeedChannels); +void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKernelPlan* plan, int nChannels, int& support, int &cuCount); +void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes); +#endif #endif diff --git a/src/init.cc b/src/init.cc index 8553ee92cc..86187bd35d 100644 --- a/src/init.cc +++ b/src/init.cc @@ -168,7 +168,7 @@ ncclResult_t checkHostUncacheMemSetting(struct ncclComm* comm) { else { return ncclSuccess; } - #endif + #endif } static void initOnceFunc() { @@ -1075,7 +1075,10 @@ NCCL_PARAM(GraphDumpFileRank, "GRAPH_DUMP_FILE_RANK", 0); NCCL_PARAM(CollNetNodeThreshold, "COLLNET_NODE_THRESHOLD", 2); NCCL_PARAM(NvbPreconnect, "NVB_PRECONNECT", 0); NCCL_PARAM(AllocP2pNetLLBuffers, "ALLOC_P2P_NET_LL_BUFFERS", 0); - +#ifdef ENABLE_WARP_SPEED +extern int64_t rcclParamWarpSpeedEnable(); +extern int64_t rcclParamWarpSpeedAutoMode(); +#endif // MNNVL: Flag to indicate whether to enable Multi-Node NVLink NCCL_PARAM(MNNVLEnable, "MNNVL_ENABLE", 2); @@ -1453,8 +1456,12 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p allGather3Data[rank].nc = 4; } } +#ifdef ENABLE_WARP_SPEED + comm->topo->warpSpeedEnabled = 
(rcclParamWarpSpeedEnable() != 0 || rcclParamWarpSpeedAutoMode() != 0); +#endif + // For single node communicators that do not uses the full xgmi links per gpu, i.e., nranks < 8 - // Inflate the nChannels a bit to achieve higher b/w. + // Inflate the nChannels a bit to achieve higher b/w. if (IsArchMatch(comm->topo->nodes[GPU].nodes[idx].gpu.gcn, "gfx950")) { if (nranks == 2 && nNodes == 1){ allGather3Data[rank].nc = 16; @@ -1464,7 +1471,12 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p allGather3Data[rank].nc = 4; } } - +#ifdef ENABLE_WARP_SPEED + // Double default channels for WarpSpeed enabled communicators + if (comm->topo->warpSpeedEnabled) { + allGather3Data[rank].nc *= 2; + } +#endif allGather3Data[rank].pivotA2AEnabled = comm->topo->pivotA2AEnabled && rcclParamPivotAlltoallEnable(); comm->topo->ll128Enabled = comm->topo->ll128Enabled || rcclParamLL128ForceEnable(); allGather3Data[rank].ll128Enabled = comm->topo->ll128Enabled; @@ -1817,8 +1829,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p // Compute time models for algorithm and protocol combinations NCCLCHECKGOTO(ncclTopoTuneModel(comm, comm->minCompCap, comm->maxCompCap, graphs), ret, fail); - INFO(NCCL_INIT, "comm:%p, nRanks:%d, nNodes:%d, coll channels:%d collnet channels:%d, nvls channels:%d, p2p channels:%d, p2p channels per peer:%d", comm, comm->nRanks, comm->nNodes, comm->nChannels, comm->nChannels, comm->nvlsChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer); - + INFO(NCCL_INIT, "comm:%p, nRanks:%d, nNodes:%d, coll channels:%d collnet channels:%d, nvls channels:%d, p2p channels:%d, p2p channels per peer:%d", comm, comm->nRanks, comm->nNodes, comm->nChannels, comm->nChannels, comm->nvlsChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer); + if (comm->intraRank == 0) { // Load ncclParamLaunchMode const char* str = ncclGetEnv("NCCL_LAUNCH_MODE"); enum ncclLaunchMode mode, modeOld; @@ -2075,10 +2087,10 @@ static 
ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) { comm->cuCount = cuCount; NCCLCHECKGOTO(initTransportsRank(comm, job->parent, timers), res, fail); - + // Check if using host uncached mem correctly NCCLCHECK(checkHostUncacheMemSetting(comm)); - + // RCCL: determine and set unroll factor for comm NCCLCHECK(commSetUnrollFactor(comm)); diff --git a/src/rccl_wrap.cc b/src/rccl_wrap.cc index 363e39ef9e..caf8a79488 100644 --- a/src/rccl_wrap.cc +++ b/src/rccl_wrap.cc @@ -34,6 +34,14 @@ RCCL_PARAM(PipelineAllDTypes, "PIPELINE_ALL_DATA_TYPES", 0); // Otherwise, it is automatically set for certain archs, datatypes and reduction collectives RCCL_PARAM(disableReduceCopyPipelining, "DISABLE_REDUCE_COPY_PIPELINING", 0); RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 75497472); +RCCL_PARAM(ThreadsPerBlock, "THREADS_PER_BLOCK", -1); +RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", -1); +#ifdef ENABLE_WARP_SPEED +RCCL_PARAM(WarpSpeedCuCount, "WARP_SPEED_CU_COUNT", 0); +RCCL_PARAM(WarpSpeedAutoMode, "WARP_SPEED_AUTO", 0); +RCCL_PARAM(WarpSpeedEnable, "WARP_SPEED_ENABLE", 0); +#endif +#define RCCL_WARP_SPEED_MIN_BYTES (1ULL << 26) // 64 MB void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) { // Honor user input for protocol choice @@ -162,6 +170,10 @@ ncclResult_t rcclOverrideChannels(struct ncclComm* comm, ncclFunc_t coll, size_t } } +#ifdef ENABLE_WARP_SPEED + // fallback to max 64 channels and tune warp speed channels later + nc = std::min(nc, 64); +#endif return ncclSuccess; } @@ -295,7 +307,11 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co NCCLCHECK(getAlgoInfo(comm, &task, collNetSupport, nvlsSupport, numPipeOps)); *algo = task.algorithm; *protocol = task.protocol; +#ifdef ENABLE_WARP_SPEED + *maxChannels = task.useWarpSpeed? 
task.nMaxChannels / task.nWarps : task.nMaxChannels; +#else *maxChannels = task.nMaxChannels; +#endif return ncclSuccess; } @@ -398,6 +414,88 @@ void rcclSetP2pNetChunkSize(struct ncclComm* comm, int& rcclP2pNetChunkSize) { } rcclP2pNetChunkSize = p2pNetChunkSize; } +#ifdef ENABLE_WARP_SPEED +void rcclSetWarpSpeedCUs(struct ncclComm* comm, int algo, int threadsPerBlock, int& rcclWarpSpeedChannels) { + static int userChannelControlInput = RCCL_VALUE_UNSET; + int warpsPerBlock = threadsPerBlock / comm->WarpSize; + // only adjust channels for RING algorithm + if(algo != NCCL_ALGO_RING) { + return; + } + if (userChannelControlInput == RCCL_VALUE_UNSET) { + const char *inputStr = getenv("NCCL_THREAD_THRESHOLDS"); + if (!inputStr) { + inputStr = getenv("NCCL_MAX_NCHANNELS"); + } + if (!inputStr) { + inputStr = getenv("NCCL_MIN_NCHANNELS"); + } + userChannelControlInput = !inputStr ? 0 : 1; + } + if(!userChannelControlInput && comm->topo->warpSpeedEnabled) { + if(rcclParamWarpSpeedCuCount() != 0) { + rcclWarpSpeedChannels = rcclParamWarpSpeedCuCount() * warpsPerBlock; + INFO(NCCL_INIT, "RCCL Warp CU count set to user defined %d resulting in %d channels", rcclParamWarpSpeedCuCount(), rcclWarpSpeedChannels); + return; + } + // reuse the existing channel tuning logic if possible + if (comm->nNodes == 1) { + rcclWarpSpeedChannels = rcclWarpSpeedChannels * warpsPerBlock / 2; // use 50% CUs for single node case + } else { + rcclWarpSpeedChannels = std::min(256, rcclWarpSpeedChannels * warpsPerBlock); + } + INFO(NCCL_INIT, "RCCL Warp Speed Channels set to %d", rcclWarpSpeedChannels); + } +} + +void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKernelPlan* plan, int nChannels, int& support, int &cuCount) { + if(!comm->topo->warpSpeedEnabled) { + support = 0; + cuCount = nChannels; + return; + } + // WarpSpeed is not supported currently for the following cases: + // 1. if any work batch in the plan contains P2P work + // 2. 
or any collective task is not using RING algorithm + bool hasP2p = !ncclIntruQueueEmpty(&plan->p2pTaskQueue); + bool hasNonRing = false; + struct ncclTaskColl* task = ncclIntruQueueHead(&plan->collTaskQueue); + while (task != nullptr) { + if (task->algorithm != NCCL_ALGO_RING || !(task->useWarpSpeed)) { + hasNonRing = true; + break; + } + task = task->next; + } + int warpsPerBlock = plan->threadPerBlock / comm->WarpSize; + support = (hasP2p || hasNonRing) ? 0 : 1; + cuCount = (support == 0)? nChannels : nChannels / warpsPerBlock + ((nChannels % warpsPerBlock) != 0 ? 1 : 0); // each CU can handle warpsPerBlock +} + +void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes) { + info->useWarpSpeed = false; + if(!comm->topo->warpSpeedEnabled) { + return; + } + info->useWarpSpeed = (info->algorithm == NCCL_ALGO_RING); // Enabled by default for any RING algorithm when platform supports it + if(rcclParamWarpSpeedAutoMode() != 0 && IsArchMatch(comm->archName, "gfx950")) { // Auto mode only available for gfx950 currently + size_t minBytes = 0; + if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather) minBytes = RCCL_WARP_SPEED_MIN_BYTES; + else if (info->func == ncclFuncReduceScatter) minBytes = RCCL_WARP_SPEED_MIN_BYTES << 2; // ReduceScatter requires higher message size to benefit from WarpSpeed + if(comm->nNodes == 1) { + if(nBytes >= minBytes && minBytes > 0) { + comm->unroll = NCCL_UNROLL_2; + info->nWarps = 4; + } + } else { + // TODO: set unroll factor per task rather than per comm + commSetUnrollFactor(comm); + info->useWarpSpeed = false; + } + } + +} +#endif void rcclGetMaxNthreads(struct ncclComm* comm, int maxNthreads[]) { if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950")) { @@ -411,12 +509,27 @@ void rcclGetMaxNthreads(struct ncclComm* comm, int maxNthreads[]) { void rcclOptThreadBlockSize(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes, int& nThreads) { static int 
maxNthreads[NCCL_NUM_PROTOCOLS] = {0}; if (maxNthreads[NCCL_PROTO_SIMPLE] == 0) rcclGetMaxNthreads(comm, maxNthreads); + if(rcclParamThreadsPerBlock() != -1) { + nThreads = rcclParamThreadsPerBlock(); + if(nThreads % comm->WarpSize != 0) { + nThreads = ((nThreads / comm->WarpSize) + 1) * comm->WarpSize; + INFO(NCCL_INIT, "RCCL Threads per block adjusted to %d to be multiple of warp size %d", nThreads, comm->WarpSize); + } + if(nThreads > maxNthreads[NCCL_PROTO_SIMPLE]) { + nThreads = maxNthreads[NCCL_PROTO_SIMPLE]; + INFO(NCCL_INIT, "RCCL Threads per block reduced to %d to match max threads", nThreads); + } else if (nThreads < 3 * comm->WarpSize) { + nThreads = 3 * comm->WarpSize; // min requirement for tree + INFO(NCCL_INIT, "RCCL Threads per block increased to %d to be at least one warp", nThreads); + } + return; + } if (info->algorithm == NCCL_ALGO_TREE) nThreads = maxNthreads[NCCL_PROTO_SIMPLE]; // Tree now uses all threads always. if (info->algorithm == NCCL_ALGO_PAT) nThreads = maxNthreads[NCCL_PROTO_SIMPLE]; if (comm->nNodes == 1) nThreads = RCCL_SINGLE_NODE_MAX_NTHREADS; // For single node, we use half the number of threads for perf reasons. 
// The following should be already set correctly by getNthreads // but need to override the changes for TREE and PAT in the previous lines - if (info->protocol == NCCL_PROTO_LL) nThreads = maxNthreads[NCCL_PROTO_LL]; + else if (info->protocol == NCCL_PROTO_LL) nThreads = maxNthreads[NCCL_PROTO_LL]; // ReduceScatter small count optimization if (info->func == ncclFuncReduceScatter && divUp(nBytes, comm->nRanks) <= 524288) nThreads = maxNthreads[NCCL_PROTO_LL]; } @@ -436,6 +549,15 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, } ncclResult_t commSetUnrollFactor(struct ncclComm* comm) { + if( rcclParamUnrollFactor() != -1 ) { + comm->unroll = rcclParamUnrollFactor(); // used directly as the 0-based unroll index (effective unroll = 2^value) + if(comm->unroll < NCCL_UNROLL_1 || comm->unroll >= NCCL_NUM_UNROLLS) { + WARN("Invalid RCCL_UNROLL_FACTOR %d specified. Valid values are 0 to 2 corresponding to unroll factors of 1, 2, and 4 respectively.", comm->unroll); + return ncclInvalidArgument; + } + INFO(NCCL_INIT, "RCCL Unroll Factor (user set): %d", (int) (pow(2.0, (double)comm->unroll))); + return ncclSuccess; + } if(IsArchMatch(comm->archName, "gfx950")) { if(comm->nNodes == 1) comm->unroll = NCCL_UNROLL_1;