From 7bbce085ccd52219043e7ba702720fdf3bbafa0a Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Thu, 8 Sep 2022 14:45:27 -0700 Subject: [PATCH] Enable LL128 protocol support (#605) * Enable LL128 protocol support * Use shared memory object directly when possible --- src/collectives/device/all_reduce.h | 4 - src/collectives/device/common.h | 318 +++++++++++++++++---------- src/collectives/device/op128.h | 19 -- src/collectives/device/prims_ll128.h | 176 +++++++++------ src/enqueue.cc | 44 ++-- src/graph/rome_models.cc | 2 + src/graph/topo.h | 5 + src/graph/tuning.cc | 27 ++- src/include/collectives.h | 14 +- src/include/comm.h | 3 +- src/include/devcomm.h | 6 +- src/include/enqueue.h | 1 + src/init.cc | 10 +- 13 files changed, 382 insertions(+), 247 deletions(-) diff --git a/src/collectives/device/all_reduce.h b/src/collectives/device/all_reduce.h index 79aa3fbfa8..809177f41d 100644 --- a/src/collectives/device/all_reduce.h +++ b/src/collectives/device/all_reduce.h @@ -574,7 +574,6 @@ template struct RunWorkElement { __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { using Proto = ProtoSimple; - if (threadIdx.x == 0) __insert_timestamp(__LINE__); runRing(args); } }; @@ -582,7 +581,6 @@ struct RunWorkElement struct RunWorkElement { __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { - if (threadIdx.x == 0) __insert_timestamp(__LINE__); runTreeUpDown>(args); } }; @@ -688,7 +686,6 @@ struct RunWorkElement struct RunWorkElement { __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { - if (threadIdx.x == 0) __insert_timestamp(__LINE__); runRing(args); } }; @@ -696,7 +693,6 @@ struct RunWorkElement struct RunWorkElement { __device__ __attribute__((noinline)) void run(ncclWorkElem *args) { - if (threadIdx.x == 0) __insert_timestamp(__LINE__); if (args->pad_0 == 0) runTreeUpDown(args); else runTreeSplit(args); } diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index 343c4010d6..05eef2c05b 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -102,26 +102,112 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{ #endif }; -template +#define NCCL_FUNC5_LL128(func, algo, devredop, type, nullify) \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \ + MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type)) + +#define NCCL_FUNC4_LL128(func, devredop, type, nullify) \ + NCCL_FUNC5_LL128(func, TREE, devredop, type, nullify), \ + NCCL_FUNC5_LL128(func, RING, devredop, type, nullify), \ + NCCL_FUNC5_LL128(func, COLLNET, devredop, type, nullify) + +// Must be consistent with ncclDataType_t +#define NCCL_FUNCS3A_LL128(func, devredop, nullForFloat) \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, uint8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int32_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, uint32_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int64_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, uint64_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, half, nullForFloat), \ + NCCL_FUNC4_LL128(func, devredop, float, nullForFloat), \ + NCCL_FUNC4_LL128(func, devredop, double, nullForFloat), \ + NCCL_FUNC4_LL128(func, devredop, rccl_bfloat16, nullForFloat) +#define NCCL_FUNCS3B_LL128(func, devredop) \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \ + NCCL_FUNC4_LL128(func, devredop, int8_t, 0) + +// Must be consistent with ncclRedOp_t +#define NCCL_FUNCS2A_LL128(func) \ + NCCL_FUNCS3A_LL128(func, Sum, /*nullForFloat=*/0), \ + NCCL_FUNCS3A_LL128(func, Prod, /*nullForFloat=*/0), \ + NCCL_FUNCS3A_LL128(func, Max, /*nullForFloat=*/0), \ + NCCL_FUNCS3A_LL128(func, Min, /*nullForFloat=*/0), \ + NCCL_FUNCS3A_LL128(func, PreMulSum, /*nullForFloat=*/0), \ + NCCL_FUNCS3A_LL128(func, SumPostDiv, /*nullForFloat=*/1) + +#define NCCL_FUNCS2B_LL128(func) \ + NCCL_FUNCS3B_LL128(func, Sum), \ + NCCL_FUNCS3B_LL128(func, Sum), \ + NCCL_FUNCS3B_LL128(func, Sum), \ + NCCL_FUNCS3B_LL128(func, Sum), \ + NCCL_FUNCS3B_LL128(func, Sum), \ + NCCL_FUNCS3B_LL128(func, Sum) + +// Must be consistent with the ncclFuncSet enum +using ncclKernelFunc_t = void (*)(); + +static const __device__ constexpr ncclKernelFunc_t ncclFuncs_ll128[]{ +// Don't try to initialize the host shadow copy of this device-side global +// variable. There is no host pointer to a device-side function, which +// confuses clang. This will be fixed in the next clang release. +#if defined(__HIP_DEVICE_COMPILE__) +#if defined(BUILD_ALLREDUCE_ONLY) + NCCL_FUNC4_LL128(AllReduce, Sum, float, 0), +#else + NCCL_FUNCS2B_LL128(Broadcast), + NCCL_FUNCS2A_LL128(Reduce), + NCCL_FUNCS2B_LL128(AllGather), + NCCL_FUNCS2A_LL128(ReduceScatter), + NCCL_FUNCS2A_LL128(AllReduce), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, half), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, float), + NCCL_ONERANK_REDUCE_NAME(PreMulSum, double), +#if defined(RCCL_BFLOAT16) + NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16), +#endif + NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), + NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t), +#endif +#endif +}; + +template struct Caller { static __device__ __host__ void call(unsigned short funcIndex) noexcept { constexpr unsigned short m = f + (l - f) / 2; - return (funcIndex < m) ? Caller::call(funcIndex) : Caller::call(funcIndex); + return (funcIndex < m) ? Caller::call(funcIndex) : Caller::call(funcIndex); } }; -template -struct Caller{ +template +struct Caller{ static __device__ __host__ - void call(unsigned short funcIndex) noexcept { ncclFuncs[f](); } + void call(unsigned short funcIndex) noexcept { if (u) ncclFuncs_ll128[f](); else ncclFuncs[f](); } }; static_assert(FUNC_INDEX_P2P == 2710, "Wrong P2P function index"); static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 2711, "Wrong AllToAllPivot function index"); +template inline __device__ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept { @@ -130,43 +216,59 @@ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept { ncclFunction_AllReduce_RING_SIMPLE_Sum_float(); else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL)) ncclFunction_AllReduce_RING_LL_Sum_float(); - else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128)) + else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128)) + ncclFunction_AllReduce_RING_LL128_Sum_float(); + else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128)) ncclFunction_AllReduce_RING_LL_Sum_float(); else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE)) ncclFunction_AllReduce_TREE_SIMPLE_Sum_float(); else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL)) ncclFunction_AllReduce_TREE_LL_Sum_float(); + else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128)) + ncclFunction_AllReduce_TREE_LL128_Sum_float(); + else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128)) + ncclFunction_AllReduce_TREE_LL_Sum_float(); else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE)) ncclFunction_AllReduce_COLLNET_SIMPLE_Sum_float(); else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL)) ncclFunction_AllReduce_COLLNET_LL_Sum_float(); + else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL128)) + ncclFunction_AllReduce_COLLNET_LL128_Sum_float(); + else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL128)) + ncclFunction_AllReduce_COLLNET_LL_Sum_float(); else assert("Unsupported function index"); #else if (funcIndex < 540) { if (funcIndex % 9 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); - else if (funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(); else if (funcIndex % 9 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t(); else if (funcIndex % 9 == 3) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); - else if (funcIndex % 9 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 9 == 4) ncclFunction_Broadcast_RING_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 9 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t(); else if (funcIndex % 9 == 5) ncclFunction_Broadcast_RING_SIMPLE_Sum_int8_t(); else if (funcIndex % 9 == 6) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t(); - else if (funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t(); else ncclFunction_Broadcast_COLLNET_SIMPLE_Sum_int8_t(); } - else if (funcIndex < 1080) Caller<540, 1080>::call(funcIndex); + else if (funcIndex < 1080) Caller<540, 1080, USING_LL128>::call(funcIndex); else if (funcIndex < 1620) { if (funcIndex % 9 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); - else if (funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t(); else if (funcIndex % 9 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t(); else if (funcIndex % 9 == 3) ncclFunction_AllGather_RING_LL_Sum_int8_t(); - else if (funcIndex % 9 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 9 == 4) ncclFunction_AllGather_RING_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 9 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t(); else if (funcIndex % 9 == 5) ncclFunction_AllGather_RING_SIMPLE_Sum_int8_t(); else if (funcIndex % 9 == 6) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t(); - else if (funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t(); + else if (USING_LL128 && funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL128_Sum_int8_t(); + else if (!USING_LL128 && funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t(); else ncclFunction_AllGather_COLLNET_SIMPLE_Sum_int8_t(); } - else if (funcIndex < 2700) Caller<1620, 2700>::call(funcIndex); + else if (funcIndex < 2700) Caller<1620, 2700, USING_LL128>::call(funcIndex); else { switch (funcIndex - 2700) { case 0: @@ -219,52 +321,52 @@ class ncclFunction { #ifdef ENABLE_COLLTRACE #define traceColl(elem,launch_type) \ - uint32_t pos = __atomic_fetch_add(ncclShmem->comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \ - ncclShmem->comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \ - ncclShmem->comm.collTrace[pos].bid = blockIdx.x; \ - ncclShmem->comm.collTrace[pos].funcIndex = ncclShmem->work.header.funcIndex; \ - asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (ncclShmem->comm.collTrace[pos].data_0)); \ + uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \ + shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \ + shmem.comm.collTrace[pos].bid = blockIdx.x; \ + shmem.comm.collTrace[pos].funcIndex = shmem.work.header.funcIndex; \ + asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (shmem.comm.collTrace[pos].data_0)); \ if (elem.header.type == ncclWorkTypeP2p) { \ struct ncclWorkElemP2p *p2pElems = (struct ncclWorkElemP2p *)&elem; \ - ncclShmem->comm.collTrace[pos].p2p[0].connIndex = p2pElems[0].connIndex; \ - ncclShmem->comm.collTrace[pos].p2pOpCount[0] = p2pElems[0].opCount; \ - ncclShmem->comm.collTrace[pos].p2p[0].ngroups = p2pElems[0].ngroups; \ - ncclShmem->comm.collTrace[pos].p2p[0].nWarps = p2pElems[0].nWarps; \ - ncclShmem->comm.collTrace[pos].p2p[0].warpStart = p2pElems[0].warpStart; \ - ncclShmem->comm.collTrace[pos].p2p[0].peer = (uint16_t)(p2pElems[0].peer); \ - ncclShmem->comm.collTrace[pos].p2p[1].connIndex = p2pElems[1].connIndex; \ - ncclShmem->comm.collTrace[pos].p2pOpCount[1] = p2pElems[1].opCount; \ - ncclShmem->comm.collTrace[pos].p2p[1].ngroups = p2pElems[1].ngroups; \ - ncclShmem->comm.collTrace[pos].p2p[1].nWarps = p2pElems[1].nWarps; \ - ncclShmem->comm.collTrace[pos].p2p[1].warpStart = p2pElems[1].warpStart; \ - ncclShmem->comm.collTrace[pos].p2p[1].peer = (uint16_t)(p2pElems[1].peer); \ - ncclShmem->comm.collTrace[pos].type = (ncclCollTraceP2pElemType|launch_type); \ + shmem.comm.collTrace[pos].p2p[0].connIndex = p2pElems[0].connIndex; \ + shmem.comm.collTrace[pos].p2pOpCount[0] = p2pElems[0].opCount; \ + shmem.comm.collTrace[pos].p2p[0].ngroups = p2pElems[0].ngroups; \ + shmem.comm.collTrace[pos].p2p[0].nWarps = p2pElems[0].nWarps; \ + shmem.comm.collTrace[pos].p2p[0].warpStart = p2pElems[0].warpStart; \ + shmem.comm.collTrace[pos].p2p[0].peer = (uint16_t)(p2pElems[0].peer); \ + shmem.comm.collTrace[pos].p2p[1].connIndex = p2pElems[1].connIndex; \ + shmem.comm.collTrace[pos].p2pOpCount[1] = p2pElems[1].opCount; \ + shmem.comm.collTrace[pos].p2p[1].ngroups = p2pElems[1].ngroups; \ + shmem.comm.collTrace[pos].p2p[1].nWarps = p2pElems[1].nWarps; \ + shmem.comm.collTrace[pos].p2p[1].warpStart = p2pElems[1].warpStart; \ + shmem.comm.collTrace[pos].p2p[1].peer = (uint16_t)(p2pElems[1].peer); \ + shmem.comm.collTrace[pos].type = (ncclCollTraceP2pElemType|launch_type); \ } else { \ - ncclShmem->comm.collTrace[pos].opCount = elem.opCount; \ - ncclShmem->comm.collTrace[pos].coll.nWarps = elem.header.nWarps; \ - ncclShmem->comm.collTrace[pos].coll.bid = elem.bid; \ - ncclShmem->comm.collTrace[pos].coll.nChannels = elem.nChannels; \ - ncclShmem->comm.collTrace[pos].type = (ncclCollTraceCollElemType|launch_type); \ + shmem.comm.collTrace[pos].opCount = elem.opCount; \ + shmem.comm.collTrace[pos].coll.nWarps = elem.header.nWarps; \ + shmem.comm.collTrace[pos].coll.bid = elem.bid; \ + shmem.comm.collTrace[pos].coll.nChannels = elem.nChannels; \ + shmem.comm.collTrace[pos].type = (ncclCollTraceCollElemType|launch_type); \ } #define traceKernelLaunch(elem,firstLaunch) { \ traceColl(elem,(firstLaunch?ncclCollTraceKernelLaunchType:ncclCollTraceCollLaunchType)); \ } #define traceKernelEnd() { \ - uint32_t pos = __atomic_fetch_add(ncclShmem->comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \ - ncclShmem->comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \ - ncclShmem->comm.collTrace[pos].bid = bid; \ - ncclShmem->comm.collTrace[pos].type = ncclCollTraceKernelEndType; \ + uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \ + shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \ + shmem.comm.collTrace[pos].bid = bid; \ + shmem.comm.collTrace[pos].type = ncclCollTraceKernelEndType; \ } #define traceAbort() { \ - uint32_t pos = __atomic_fetch_add(ncclShmem->comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \ - ncclShmem->comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \ - ncclShmem->comm.collTrace[pos].bid = bid; \ - ncclShmem->comm.collTrace[pos].type = ncclCollTraceAbortType; \ + uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \ + shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \ + shmem.comm.collTrace[pos].bid = bid; \ + shmem.comm.collTrace[pos].type = ncclCollTraceAbortType; \ } // traceData(int16_t data2, uint32_t data4, uint64_t data8_0, uint64_t data8_1) #define traceData(data2, data4, data8_0, data8_1) { \ - uint32_t pos = __atomic_fetch_add(ncclShmem->comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \ + uint32_t pos = atomicAdd(ncclShmem->comm.collTraceTail, 1)%COLLTRACE_NUM_ITEMS; \ ncclShmem->comm.collTrace[pos].bid = blockIdx.x; \ ncclShmem->comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \ ncclShmem->comm.collTrace[pos].funcIndex = data2; \ @@ -281,22 +383,16 @@ class ncclFunction { #ifdef ENABLE_PROFILING #define __insert_timestamp(line_num) do { \ - if (ncclShmem->prof.count < PROFILE_NUM_ITEMS) { \ - ncclShmem->prof.elem[ncclShmem->prof.count].line = line_num; \ - ncclShmem->prof.elem[ncclShmem->prof.count].timeStamp = __builtin_amdgcn_s_memrealtime(); \ - ncclShmem->prof.count++; \ + if (shmem.prof.count < PROFILE_NUM_ITEMS) { \ + shmem.prof.elem[shmem.prof.count].line = line_num; \ + shmem.prof.elem[shmem.prof.count].timeStamp = __builtin_amdgcn_s_memrealtime(); \ + shmem.prof.count++; \ } \ } while(0); #else #define __insert_timestamp(line_num) #endif -__device__ inline bool barrierReduceAny(int bit, uint32_t* abortCount) { - if (bit) atomicAdd(abortCount, 1); - __syncthreads(); - return atomicAdd(abortCount, 0) != 0; -} - // Copy src to dst and fill extra size with zeroes template __device__ void copyToShmem(Tdst *dst, Tsrc const *src, int tid, int nthreads) { @@ -385,19 +481,18 @@ struct ncclShmemGroup { struct ncclShmemData { union { - uint64_t ll128warp[NCCL_MAX_GROUPS][NCCL_MAX_GROUPS]; struct ncclShmemGroup groups[NCCL_MAX_GROUPS]; }; - uint32_t sync[NCCL_MAX_GROUPS]; uint64_t redOpArgs[NCCL_MAX_DIRECT_ARITY+1]; struct ncclDevComm comm; struct ncclChannel channel; - uint64_t pad; + uint64_t pad[2]; struct ncclWork work; #ifdef ENABLE_PROFILING struct ncclProf prof; #endif }; +static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "shmem.work needs to be 16B aligned"); static __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we) { if (we->header.type != ncclWorkTypeUnused && we->redOpArgIsPtr) { @@ -422,16 +517,14 @@ static __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we) { extern __device__ struct ncclShmemData *ncclShmem; -template -__device__ void ncclKernel(struct ncclDevComm* comm, ncclWorkElem first) { +template +__device__ void ncclKernel(struct ncclDevComm* comm) { int tid = threadIdx.x; int nthreads = blockDim.x; int bid = blockIdx.x; __shared__ struct ncclShmemData shmem; ncclShmem = &shmem; - __shared__ uint32_t abortCount; if (tid == 0) { - abortCount = 0; for (auto i = 0; i < NCCL_MAX_GROUPS; i++) { shmem.groups[i].barrier = 0; for (auto j = 0; j < NCCL_MAX_GROUPS; j++) shmem.groups[i].barrier_next[j] = 0; @@ -439,48 +532,38 @@ __device__ void ncclKernel(struct ncclDevComm* comm, ncclWorkElem first) { } __syncthreads(); - int turn = copyToShmem(&ncclShmem->comm, comm); + int turn = copyToShmem(&shmem.comm, comm); #ifdef ENABLE_PROFILING if (tid == 0) { - ncclShmem->prof.count = 0; - ncclShmem->prof.seq = ncclShmem->comm.devProf[bid].seq; + shmem.prof.count = 0; + shmem.prof.seq = shmem.comm.devProf[bid].seq; } #endif if (tid == 0) __insert_timestamp(__LINE__); // get address of channel without incurring indirect load from ncclDevCom::channels ncclChannel *channel = &((ncclDevCommAndChannels*)comm)->channels[bid]; - turn = copyToShmem(&ncclShmem->channel, channel, turn); + turn = copyToShmem(&shmem.channel, channel, turn); - // To optimize for latency, (only) the first operation is passed as argument. - if (bid == 0 && first.header.type != ncclWorkTypeUnused) { - // Copy first elem to work and zero out the rest - copyToShmem(&ncclShmem->work, &first, tid, nthreads); - } __syncthreads(); // publish ncclShmem if (tid == 0) __insert_timestamp(__LINE__); if (tid == 0) __insert_timestamp(__LINE__); - ncclWork *workFifoHost = ncclShmem->channel.workFifo; - ncclWork *workFifoDev = ncclShmem->channel.workFifoDev; - int workFifoIx = ncclShmem->channel.index; - - bool skipLoadWork = false, firstLaunch = true; - if (bid == 0 && first.header.type != ncclWorkTypeUnused) - skipLoadWork = true; + ncclWork *workFifoHost = shmem.channel.workFifo; + ncclWork *workFifoDev = shmem.channel.workFifoDev; + int workFifoIx = shmem.channel.index; + bool firstLaunch = true; while (true) { - if (!skipLoadWork) { - copyToShmem(&ncclShmem->work, &workFifoDev[workFifoIx], tid, nthreads); - if (tid == 0) __insert_timestamp(__LINE__); - { // Check whether the last operation was aborted and make sure all threads exit - int aborted = tid == 0 ? *comm->abortFlag : 0; - if (barrierReduceAny(aborted, &abortCount)) { // publish ncclShmem->work - if (COLLTRACE && tid == 0) traceAbort(); - break; - } - if (tid == 0) - workFifoHost[workFifoIx].header.type = ncclWorkTypeUnused; + copyToShmem(&shmem.work, &workFifoDev[workFifoIx], tid, nthreads); + if (tid == 0) __insert_timestamp(__LINE__); + { // Check whether the last operation was aborted and make sure all threads exit + int aborted = tid == 0 ? *comm->abortFlag : 0; + if (__any(aborted)) { // publish shmem.work + if (COLLTRACE && tid == 0) traceAbort(); + break; } + if (tid == 0) + workFifoHost[workFifoIx].header.type = ncclWorkTypeUnused; } if (tid == 0) __insert_timestamp(__LINE__); @@ -489,47 +572,59 @@ __device__ void ncclKernel(struct ncclDevComm* comm, ncclWorkElem first) { channel->index = workFifoIx; // write back to real channel, not shmem shadow __syncwarp(); - if (ncclShmem->work.header.type == ncclWorkTypeColl) { - if (tid < NCCL_MAX_WORK_ELEMENTS) ncclRedopPtrDeref(&ncclShmem->work.elems[tid]); - } else if (ncclShmem->work.header.type == ncclWorkTypeRegColl) { - if (tid < NCCL_MAX_WORK_ELEMENTS_REG) ncclRedopPtrDeref(&ncclShmem->work.regElems[tid].elem); + if (shmem.work.header.type == ncclWorkTypeColl) { + if (tid < NCCL_MAX_WORK_ELEMENTS) ncclRedopPtrDeref(&shmem.work.elems[tid]); + } else if (shmem.work.header.type == ncclWorkTypeRegColl) { + if (tid < NCCL_MAX_WORK_ELEMENTS_REG) ncclRedopPtrDeref(&shmem.work.regElems[tid].elem); } __syncthreads(); if (COLLTRACE && tid == 0) { - traceKernelLaunch(ncclShmem->work.elems[0],firstLaunch); + traceKernelLaunch(shmem.work.elems[0],firstLaunch); firstLaunch = false; #pragma unroll 1 - for(int e=1; e < NCCL_MAX_WORK_ELEMENTS && ncclShmem->work.elems[e].header.type != ncclWorkTypeUnused; e ++) { - traceColl(ncclShmem->work.elems[e], 0); + for(int e=1; e < NCCL_MAX_WORK_ELEMENTS && shmem.work.elems[e].header.type != ncclWorkTypeUnused; e ++) { + traceColl(shmem.work.elems[e], 0); } } if (tid == 0) __insert_timestamp(__LINE__); - if (ncclShmem->work.header.funcIndex == FnIndex) - RunWork().run(&ncclShmem->work); + if (shmem.work.header.funcIndex == FnIndex) + RunWork().run(&shmem.work); else - NCCL_CALL_FUNCTIONS(ncclShmem->work.header.funcIndex); + NCCL_CALL_FUNCTIONS(shmem.work.header.funcIndex); - if (ncclShmem->work.header.isLast) break; + if (shmem.work.header.isLast) break; __syncthreads(); - skipLoadWork = false; } if (COLLTRACE && tid == 0) traceKernelEnd() #ifdef ENABLE_PROFILING - if (ncclShmem->comm.devProf->seq < PROFILE_NUM_LAUNCHES) { - copyToShmem(ncclShmem->comm.devProf+MAXCHANNELS*ncclShmem->prof.seq+blockIdx.x, &ncclShmem->prof); - if (tid == 0) ncclShmem->comm.devProf[bid].seq++; + if (shmem.comm.devProf->seq < PROFILE_NUM_LAUNCHES) { + __syncthreads(); + copyToShmem(shmem.comm.devProf+MAXCHANNELS*shmem.prof.seq+blockIdx.x, &shmem.prof); + if (tid == 0) shmem.comm.devProf[bid].seq++; } #endif } #define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \ __launch_bounds__(NCCL_MAX_NTHREADS, 1) \ -__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, ncclWorkElem first) { \ - if (comm->collTraceThread) \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(comm, first); \ - else \ - ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, first); \ +__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm) { \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false, false>(comm); \ +} \ + \ +__launch_bounds__(NCCL_MAX_NTHREADS, 1) \ +__global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm) { \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true, false>(comm); \ +} \ + \ +__launch_bounds__(NCCL_MAX_NTHREADS, 1) \ +__global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm) { \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false, true>(comm); \ +} \ + \ +__launch_bounds__(NCCL_MAX_NTHREADS, 1) \ +__global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm) { \ + ncclKernel, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true, true>(comm); \ } // Examples : AllReduce, RING, LL, Sum, uint8 @@ -542,7 +637,8 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev // Only generate inline kernels for LL #define IMPL_COLL4(func, algo, devredop, type, ncclType) \ IMPL_COLL_FUNC(func, algo, LL, devredop, type) \ - IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) \ + IMPL_COLL_FUNC(func, algo, LL128, devredop, type) \ + IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) #define IMPL_COLL3(func, devredop, type, ncclType) \ IMPL_COLL4(func, TREE, devredop, type, ncclType) \ diff --git a/src/collectives/device/op128.h b/src/collectives/device/op128.h index f65117941a..88f5b917e3 100644 --- a/src/collectives/device/op128.h +++ b/src/collectives/device/op128.h @@ -9,25 +9,6 @@ #define OP128_H_ #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) -inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) { - return 0; -} - -inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) { -} - -inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) { -} - -inline __device__ void loadShmem128(uint64_t* shmemAsmPtr, uint64_t &v0, uint64_t &v1) { -} - -inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_t v1) { -} - -template -inline __device__ void loadShmemMisaligned128(T *ptr, uint64_t &v0, uint64_t &v1) { -} #else inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) { asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" diff --git a/src/collectives/device/prims_ll128.h b/src/collectives/device/prims_ll128.h index 972ce9d091..bd484a473e 100644 --- a/src/collectives/device/prims_ll128.h +++ b/src/collectives/device/prims_ll128.h @@ -9,7 +9,6 @@ #define NCCL_LL128_FLAGTHREAD (NCCL_LL128_LINEELEMS-1) -#define __any_sync(WARP_MASK, needReload) (true) template class Primitives: @@ -51,20 +50,25 @@ class Primitives: inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+1; } inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; } + uint64_t* barriers; + uint64_t* barrier_next; + inline __device__ void barrier() { #if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__) - __syncthreads(); + if (nthreads != WARP_SIZE) + barrier_by_group(); #else asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(15-group)); #endif } uint32_t abort = 0; + uint32_t* sync; inline __device__ int checkAbort(int &spins, int i, int send) { spins++; if (abort == 0 && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) { - abort = *ncclShmem->comm.abortFlag; + abort = __atomic_load_n(ncclShmem->comm.abortFlag, __ATOMIC_SEQ_CST); spins = 0; } return abort; @@ -74,67 +78,69 @@ class Primitives: if (sendConnHeadPtr) { int spins = 0; while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) { - sendConnHeadCache = *sendConnHeadPtr; + __builtin_amdgcn_s_sleep(8); + sendConnHeadCache = atomicAdd_system((unsigned long long *)sendConnHeadPtr, 0); if (checkAbort(spins, wid, 1)) break; } + __asm__ __volatile__("s_wakeup"); if (sendConnFifoPtr) { - sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes; + __atomic_store_n(sendConnFifoPtr+sendStep[wid]%NCCL_STEPS, nbytes, __ATOMIC_SEQ_CST); } sendConnHead += 1; } } inline __device__ void postRecv() { - if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1; + if (recvConnHeadPtr) atomicExch_system((unsigned long long *)recvConnHeadPtr, recvConnHead += 1); } inline __device__ void postSend() { - if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; } + if (sendConnTailPtr) { __threadfence(); atomicExch_system((unsigned long long *)sendConnTailPtr, sendConnTail += 1); } } template __device__ __forceinline__ void loadRegsBegin(uint64_t(®s)[WordPerThread], T const *src, int eltN) { constexpr int EltPer16B = 16/sizeof(T); - if(reinterpret_cast(src)%16 == 0) { - /* We are aligned to 16 bytes, so load directly to registers no shmem. - * Flag threads load half as much data which gets shuffled to the even - * registers during Finish. The point of splitting into two phases is to - * defer that shuffle, which incurs a dependency stall, until after other - * memops are launched by the caller. - */ - #pragma unroll - for(int g=0; g < WordPerThread/2; g++) { - int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8); - if(!flagThread || g%2==0) { - if(ix*EltPer16B < eltN) - load128((uint64_t*)(src + ix*EltPer16B), regs[2*g+0], regs[2*g+1]); - } - } - } - else { - // Not aligned. Stage the smallest 16 byte aligned region subsuming the - // buffer into shmem. - int misalignment = reinterpret_cast(src) % 16; - uint64_t *src8 = reinterpret_cast(reinterpret_cast(src) & -uintptr_t(16)); - uint64_t *shm8 = shmemCvtPtr(ncclShmem->ll128warp[warp]); - #pragma unroll - for(int g=0; g < WordPerThread/2; g++) - if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T)) - load128(src8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]); - #pragma unroll - for(int g=0; g < WordPerThread/2; g++) - storeShmem128(shm8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]); - - __syncwarp(); - - // Now load from shmem stage to regs. Preserve the same pre-shuffled layout - // as the aligned case since Finish() will be applied regardless. - T *shm = (T*)shm8 + misalignment/sizeof(T); - #pragma unroll - for(int g=0; g < WordPerThread/2; g++) { - int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8); - if(!flagThread || g%2==0) { - if(ix*EltPer16B < eltN) - loadShmemMisaligned128(shm + ix*EltPer16B, regs[2*g+0], regs[2*g+1]); + /* We are aligned to 16 bytes, so load directly to registers no shmem. + * Flag threads load half as much data which gets shuffled to the even + * registers during Finish. The point of splitting into two phases is to + * defer that shuffle, which incurs a dependency stall, until after other + * memops are launched by the caller. + */ + #pragma unroll + for(int g=0; g < WordPerThread/2; g++) { + int ix = g*WARP_SIZE - 16*(g/2) + wid - (g%2)*(wid/4); + if(!flagThread || g%2==0) { + if(ix*EltPer16B < eltN) { + if(reinterpret_cast(src)%4 == 0) { + regs[2*g+0] = __builtin_nontemporal_load((uint64_t*)(src + ix*EltPer16B)); + regs[2*g+1] = __builtin_nontemporal_load((uint64_t*)(src + ix*EltPer16B)+1); + } else { + union { + uint64_t regs64[WordPerThread]; + uint32_t regs32[WordPerThread*2]; + uint16_t regs16[WordPerThread*4]; + uint8_t regs8[WordPerThread*8]; + }; + if (sizeof(T) == 8) { + uint64_t *src64 = (uint64_t*)(src+ix*EltPer16B); + for (int i=0; i < 2; i++) + regs64[2*g+i] = __builtin_nontemporal_load(src64+i); + } else if (sizeof(T) == 4) { + uint32_t *src32 = (uint32_t*)(src+ix*EltPer16B); + for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T); i++) + regs32[2*g+i] = __builtin_nontemporal_load(src32+i); + } else if (sizeof(T) == 2) { + uint16_t *src16 = (uint16_t*)(src+ix*EltPer16B); + for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T); i++) + regs16[2*g+i] = __builtin_nontemporal_load(src16+i); + } else if (sizeof(T) == 1) { + uint8_t *src8 = (uint8_t*)(src+ix*EltPer16B); + for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T); i++) + regs8[2*g+i] = __builtin_nontemporal_load(src8+i); + } + regs[2*g+0] = regs64[2*g+0]; + regs[2*g+1] = regs64[2*g+1]; + } } } } @@ -157,26 +163,45 @@ class Primitives: for (int g=1; g < WordPerThread/2; g+=2) { if (flagThread) regs[2*g-1] = regs[2*g]; } - // Write to dst if 16-byte aligned, shmem otherwise. - int misalignment = reinterpret_cast(dst)%16; - uint64_t *shm8 = shmemCvtPtr(ncclShmem->ll128warp[warp]); + // Write to dst if 4-byte aligned, shmem otherwise. + int misalignment = reinterpret_cast(dst)%4; #pragma unroll for(int g=0; g < WordPerThread/2; g++) { - int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8); + int ix = g*WARP_SIZE - 16*(g/2) + wid - (g%2)*(wid/4); if (!flagThread || g%2==0) { - if(misalignment == 0 && (ix+1)*EltPer16B <= eltN) - store128((uint64_t*)(dst + ix*EltPer16B), regs[2*g+0], regs[2*g+1]); - else - storeShmem128(shm8+2*ix, regs[2*g+0], regs[2*g+1]); + if(misalignment == 0 && (ix+1)*EltPer16B <= eltN) { + __builtin_nontemporal_store(regs[2*g+0], (uint64_t*)(dst + ix*EltPer16B)); + __builtin_nontemporal_store(regs[2*g+1], (uint64_t*)(dst + ix*EltPer16B)+1); + } else { + union { + uint64_t regs64[WordPerThread]; + uint32_t regs32[WordPerThread*2]; + uint16_t regs16[WordPerThread*4]; + uint8_t regs8[WordPerThread*8]; + }; + regs64[2*g+0] = regs[2*g+0]; + regs64[2*g+1] = regs[2*g+1]; + int remaining = eltN - ix*EltPer16B; + if (sizeof(T) == 8) { + uint64_t *dst64 = (uint64_t*)(dst+ix*EltPer16B); + for (int i=0; i < 2 && i < remaining; i++) + __builtin_nontemporal_store(regs64[2*g+i], dst64+i); + } else if (sizeof(T) == 4) { + uint32_t *dst32 = (uint32_t*)(dst+ix*EltPer16B); + for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T) && i < remaining; i++) + __builtin_nontemporal_store(regs32[2*g+i], dst32+i); + } else if (sizeof(T) == 2) { + uint16_t *dst16 = (uint16_t*)(dst+ix*EltPer16B); + for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T) && i < remaining; i++) + __builtin_nontemporal_store(regs16[2*g+i], dst16+i); + } else if (sizeof(T) == 1) { + uint8_t *dst8 = (uint8_t*)(dst+ix*EltPer16B); + for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T) && i < remaining; i++) + __builtin_nontemporal_store(regs8[2*g+i], dst8+i); + } + } } } - __syncwarp(); - // Write rest from shmem to dst. No need to coalesce stores to 16-bytes, - // the hardware keeps up fine. - T *shm = (T*)ncclShmem->ll128warp[warp]; - int skip = misalignment == 0 ? eltN & -EltPer16B : 0; - for(int i=skip+wid; i < eltN; i += WARP_SIZE) - dst[i] = shm[i]; } #define WARP_MASK 0xffffffff @@ -197,10 +222,11 @@ class Primitives: needReload = false; #pragma unroll for (int u=0; u: needReload = false; #pragma unroll for (int u=0; u: } } +#if !defined(__gfx1030__) + if (tid == 0) __asm__ __volatile__("buffer_wbinvl1_vol"); +#endif /************************ Send **************************/ if (SEND) { for (int i=1; i: uint64_t* ptr = sendPtr(i)+ll128Offset; #pragma unroll for (int u=0; u: public: __device__ Primitives( const int tid, const int nthreads, int const *recvPeers, int const *sendPeers, - void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0, int connIndex=0 + void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0 ): redOp(redOpArg), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE), - flagThread((tid%8)==7), group(group), + flagThread((tid%4)==3), group(group&(uint16_t)0xFFFF), stepSize(ncclShmem->comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) { + barriers = &ncclShmem->groups[this->group].barrier; + barrier_next = ncclShmem->groups[this->group].barrier_next; auto *channel = &ncclShmem->channel; int nrecv=0, nsend=0; diff --git a/src/enqueue.cc b/src/enqueue.cc index e641dfbd4f..71e77e978b 100644 --- a/src/enqueue.cc +++ b/src/enqueue.cc @@ -68,10 +68,13 @@ NCCL_FUNCS3B(func, Sum), /*PreMulSum*/ \ NCCL_FUNCS3B(func, Sum) /*SumPostDiv*/ -typedef void(*ncclKern_t)(struct ncclDevComm* comm, struct ncclWorkElem first); +typedef void(*ncclKern_t)(struct ncclDevComm* comm); // Must be consistent with the ncclFuncSet enum -static ncclKern_t const ncclKerns[1] = { +static ncclKern_t const ncclKerns[4] = { NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t), + NCCL_KERN_NAME_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t), + NCCL_KERN_NAME_LL128(SendRecv, RING, SIMPLE, Sum, int8_t), + NCCL_KERN_NAME_LL128_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t), }; // Determine the maximum kernel stack size of all CUDA kernels @@ -89,6 +92,19 @@ error: return (res != ncclSuccess) ? 0 : max; } +// Determine kernel stack size from index +size_t ncclKernLocalSize(int i) { + ncclResult_t res = ncclSuccess; + int numNcclKerns = sizeof(ncclKerns)/sizeof(ncclKerns[0]); + hipFuncAttributes attr = {0}; + if (i < numNcclKerns) + CUDACHECKGOTO(hipFuncGetAttributes(&attr, (const void*)(ncclKerns[i])), res, error); + +error: + return (res != ncclSuccess) ? 0 : attr.localSizeBytes; +} + + // Set shared memory carveout for the nccl kernels ncclResult_t ncclKernSetSharedMemoryCarveout(int carveOut) { ncclResult_t res = ncclSuccess; @@ -174,14 +190,6 @@ static ncclResult_t setupLaunch(struct ncclQueueInfo* eqInfo, int usingCudaGraph } channel->workFifo[(channel->workFifoTail-1)%NCCL_MAX_OPS].header.isLast = 1; - if (c == 0) { - // As we inline the first coll directly, we can free it immediately. - // Except P2P or aggregation or registration cases - struct ncclWork* work = channel->workFifo+((channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS); - if (work->header.type == ncclWorkTypeColl && eqInfo->elemList->count() == 1) - work->header.type = ncclWorkTypeUnused; - } - if (channel->gdrMemDesc) { // GDRCOPY support uint64_t first = (channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS; @@ -762,14 +770,7 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) { // Inline the first kernel if (params->func == NULL) { - params->func = (void *)ncclKerns[0]; - if (work->header.type == ncclWorkTypeColl) { - // Copy the first operation to the inline argument. Type may be set later to - // ncclWorkTypeUnused if we have more than one coll element. - memcpy(&comm->args, work->elems, sizeof(struct ncclWorkElem)); - comm->args.bid = 0; // Only inline for channel 0 - comm->args.header.isLast = 1; // I am so far the last element - } + params->func = (void *)ncclKerns[ncclGetKernelIndex(comm)]; } // Register and exchange input and output buffers @@ -781,9 +782,6 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) { NCCLCHECK(ncclRegBuffAndExchange(info, &eqElem->buffRegInfo)); comm->enqueueInfo->nRegBuffs += eqElem->buffRegInfo.nBuffs; work->header.type = ncclWorkTypeRegColl; - // Disable inline argument because we need kernel to copy the entire ncclWork from workFifo - // because the registered addresses are in ncclWorkElemReg - comm->args.header.type = ncclWorkTypeUnused; } return ncclSuccess; @@ -883,7 +881,6 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) { } NCCLCHECK(ncclSetupCollKernel(info)); } - comm->args.header.type = ncclWorkTypeUnused; // disable inline argument } // Reset counters comm->asyncOpCount = 0; @@ -1106,9 +1103,8 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) { // Just for CUDA kernel to know this is a P2P operation // The CUDA kernel does not use the inlined first work element as fastpath argument if (params->func == NULL) { - params->func = (void *)ncclKerns[0]; + params->func = (void *)ncclKerns[ncclGetKernelIndex(comm)]; //params->func = ncclKerns[eqElem->work.header.funcIndex]; - comm->args.header.type = ncclWorkTypeUnused; } return ncclSuccess; } diff --git a/src/graph/rome_models.cc b/src/graph/rome_models.cc index 66c0d3b6ee..3858f7574a 100644 --- a/src/graph/rome_models.cc +++ b/src/graph/rome_models.cc @@ -838,6 +838,8 @@ static void parseOptions(struct ncclTopoSystem* system, const char *options) { system->pivotA2ANumBiRings = atol(tokens[i*2+1]); } else if (strcmp(tokens[i*2], "tuning") == 0) { system->tuning = atol(tokens[i*2+1]); + } else if (strcmp(tokens[i*2], "ll128Enabled") == 0) { + system->ll128Enabled = (bool)atol(tokens[i*2+1]); } } free(str_temp); diff --git a/src/graph/topo.h b/src/graph/topo.h index d088cee3f7..1ed433b4a6 100644 --- a/src/graph/topo.h +++ b/src/graph/topo.h @@ -165,6 +165,7 @@ struct ncclTopoSystem { bool pivotA2AEnabled; int pivotA2ANumBiRings; + bool ll128Enabled; }; ncclResult_t ncclTopoGetNode(struct ncclTopoSystem* system, struct ncclTopoNode** node, int type, uint64_t id); @@ -209,4 +210,8 @@ static ncclResult_t ncclTopoRankToIndex(struct ncclTopoSystem* system, int rank, static float ncclTopoXGMISpeed(int gcn) { return gcn == 910 ? MI200_XGMI_WIDTH : VEGA_XGMI_WIDTH; } + +#define ncclGetKernelIndex(p_comm) \ + (((p_comm)->topo->ll128Enabled ? 1 : 0)*2 + ((p_comm)->hostDevComm.collTraceThread ? 1 : 0)) + #endif diff --git a/src/graph/tuning.cc b/src/graph/tuning.cc index 2246155229..d0a115b32d 100644 --- a/src/graph/tuning.cc +++ b/src/graph/tuning.cc @@ -191,30 +191,30 @@ static struct tuningModel tuning_model_3 { static struct tuningModel tuning_model_4 { .hwLat = { /* NVLINK */ - { /* Tree (LL/LL128/Simple)*/ { 0.8, 0.0, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 0.0, 3.6 }, /* CollNet (LL/LL128/Simple)*/ { 0.8, 0.0, 2.5 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.8, 1.4, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 2.2, 3.6 }, /* CollNet (LL/LL128/Simple)*/ { 0.8, 1.4, 2.5 } }, /* PCI */ { /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNet (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 } }, /* NET */ - { /* Tree (LL/LL128/Simple)*/ { 45.8, 0.0, 105.0 }, /* Ring (LL/LL128/Simple)*/ { 19.2, 0.0, 51.0 }, /* CollNet (LL/LL128/Simple)*/ { 45.8, 0.0, 105.0 } }, + { /* Tree (LL/LL128/Simple)*/ { 45.8, 62.5, 105.0 }, /* Ring (LL/LL128/Simple)*/ { 19.2, 44.6, 51.0 }, /* CollNet (LL/LL128/Simple)*/ { 45.8, 62.5, 105.0 } }, }, .bwRatio = { /* 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.12, 0.00, 1.41 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 0.00, 1.00 }, /* CollNet (LL/LL128/Simple)*/ { 1.00, 1.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.12, 0.21, 1.41 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 0.26, 1.00 }, /* CollNet (LL/LL128/Simple)*/ { 1.00, 1.00, 1.00 } }, /* more than 2 nodes */ - { /* Tree (LL/LL128/Simple)*/ { 0.12, 0.00, 1.05 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 0.00, 1.00 }, /* CollNet (LL/LL128/Simple)*/ { 1.00, 1.00, 1.00 } }, + { /* Tree (LL/LL128/Simple)*/ { 0.12, 0.21, 1.05 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 0.26, 1.00 }, /* CollNet (LL/LL128/Simple)*/ { 1.00, 1.00, 1.00 } }, }, .treeCorrectionFactor = { - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4, 1.0, 0.5, 0.8, 0.4, 0.3, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, }, - { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, }, - { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9, 0.9, 0.8, 0.8, 0.8, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.1, 0.2, 0.1, 0.1, 0.1, 0.2, 0.3, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.4, 0.1, 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.3, 0.4, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, }, + { 0.1, 0.5, 0.1, 0.4, 0.1, 0.1, 0.2, 0.1, 1.0, 0.3, 0.1, 0.1, 0.1, 1.0, 0.6, 1.0, 1.0, 1.0, 1.0, 0.9, 0.8, 0.7, 0.6, 0.7, 0.6, 0.8, 0.8, }, }, .ringCorrectionFactor = { - { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, 0.2, 0.2, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, }, - { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, }, - { 0.6, 0.4, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 1.0, 0.8, 1.0, 1.0, 1.0, 0.7, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, }, + { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4, 0.1, 0.3, 0.1, 0.1, 0.2, 0.2, 0.3, 0.1, 0.2, 0.3, 0.3, 0.2, 0.2, 0.1, 0.2, 0.1, 0.1, 0.1, 0.1, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 0.4, 1.0, 1.0, 1.0, 0.8, 0.6, 0.2, 0.4, 0.6, 0.5, 0.5, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, }, + { 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 0.9, 0.9, 0.8, 1.0, 0.4, 0.6, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, }, }, }; @@ -362,9 +362,16 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom for (int c=0; ctypeInter <= PATH_PXB) && graphs[a]->typeIntra <= PATH_NVL && + (comm->topo->nodes[GPU].nodes[0].gpu.gcn == 910 && comm->topo->ll128Enabled) ? 1 : 0; +#else // Enable LL128 by default only on Volta/Ampere+NVLink. Other cases are not tested and may cause silent data corruption. pEnable = (graphs[a]->typeInter <= PATH_PXB) && graphs[a]->typeIntra <= PATH_NVL && ((minCompCap == 70 && maxCompCap == 70) || (minCompCap == 80 && maxCompCap == 80)) ? 1 : 0; +#endif + if (comm->rank == 0 && c == 0 && a == 0) INFO(NCCL_INIT, "Using tuning table %d with LL128 %s", comm->topo->tuning, pEnable ? "enabled" : "disabled"); } if (pEnable == 0) comm->bandwidths[c][a][p] = 0; // Only disable algo for Allreduce since others only have one diff --git a/src/include/collectives.h b/src/include/collectives.h index 4f3f55bed1..c1285b6dd7 100644 --- a/src/include/collectives.h +++ b/src/include/collectives.h @@ -32,13 +32,25 @@ struct ncclDevRedOpFull { #define NCCL_KERN_NAME(func, algo, proto, devredop, type) \ ncclKernel_##func##_##algo##_##proto##_##devredop##_##type +#define NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type) \ + ncclKernelDebug_##func##_##algo##_##proto##_##devredop##_##type + +#define NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type) \ + ncclKernelLL128_##func##_##algo##_##proto##_##devredop##_##type + +#define NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type) \ + ncclKernelLL128Debug_##func##_##algo##_##proto##_##devredop##_##type + #define NCCL_IMPL_NAME(func, algo, proto) \ nccl##func##algo##proto /* Declare all collective operations */ #define DECL5(func, algo, proto, devredop, type) \ extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \ - extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, struct ncclWorkElem c); \ + extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm); \ + extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm); \ + extern __global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm); \ + extern __global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm); #define CONCAT(a,b) a##b #define MACRO_IF(cond, t, f) CONCAT(MACRO_IF_, cond)(t, f) diff --git a/src/include/comm.h b/src/include/comm.h index 359b7666b1..c7f04d44a7 100644 --- a/src/include/comm.h +++ b/src/include/comm.h @@ -197,8 +197,7 @@ struct ncclComm { int* intraCudaDevs; int* intraCGMode; // Whether we can use CUDA9 CGMD or not int* intraCC; // Only to check all have the same ComputeCap and disable CGMode if not - struct ncclWorkElem args; - void* argsptrs[2]; + void* argsptrs[1]; struct ncclProxyState proxyState; diff --git a/src/include/devcomm.h b/src/include/devcomm.h index 4f3887fe25..0a8bd5da43 100644 --- a/src/include/devcomm.h +++ b/src/include/devcomm.h @@ -76,18 +76,18 @@ union ncclLLFifoLine { // Make sure the clean mask will last for at least NCCL_NSTEPS static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK value"); -#define NCCL_LL128_LINESIZE 128 +#define NCCL_LL128_LINESIZE 64 #define NCCL_LL128_LINEELEMS (NCCL_LL128_LINESIZE/sizeof(uint64_t)) #define NCCL_LL128_DATAELEMS (NCCL_LL128_LINEELEMS-1) #define NCCL_LL128_MAX_NTHREADS 256 -#define NCCL_LL128_ELEMS_PER_THREAD 120 +#define NCCL_LL128_ELEMS_PER_THREAD 28 // Receiving from up to 3 sources is more compute intensive than sending // to 3 dests. Use 70% for reduce and 30% for bcast. #define NCCL_LL128_SPLIT(nt) ((nt*7/(10*32))*32) -#define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 2 +#define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 4 #define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS) #define NCCL_DIRECT_WRITE 0x01 diff --git a/src/include/enqueue.h b/src/include/enqueue.h index d538a1da77..882386b6ae 100644 --- a/src/include/enqueue.h +++ b/src/include/enqueue.h @@ -16,6 +16,7 @@ #define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */ size_t ncclKernMaxLocalSize(); +size_t ncclKernLocalSize(int i); ncclResult_t ncclKernSetSharedMemoryCarveout(int carveOut); ncclResult_t ncclEnqueueCheck(struct ncclInfo* info); ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast); diff --git a/src/init.cc b/src/init.cc index 8f9092de7d..31cdf595b2 100644 --- a/src/init.cc +++ b/src/init.cc @@ -333,6 +333,7 @@ static ncclResult_t commFree(ncclComm_t comm) { RCCL_PARAM(CliqueIgnoreTopo, "CLIQUE_IGNORE_TOPO", 0); RCCL_PARAM(P2pNetDisable, "P2P_NET_DISABLE", 0); RCCL_PARAM(PivotAlltoallEnable, "PIVOT_ALLTOALL_ENABLE", 1); +RCCL_PARAM(LL128ForceEnable, "LL128_FORCE_ENABLE", 0); NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2); NCCL_PARAM(DisableGraphHelper, "GRAPH_HELPER_DISABLE", 0); NCCL_PARAM(GraphRegister, "GRAPH_REGISTER", 0); @@ -383,7 +384,6 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank, int virtua comm->p2pOpCount = 0; comm->argsptrs[0] = &comm->devComm; - comm->argsptrs[1] = &comm->args; #ifdef ENABLE_PROFILING NCCLCHECK(ncclCudaCalloc(&comm->hostDevComm.devProf, MAXCHANNELS*PROFILE_NUM_LAUNCHES)); #endif @@ -703,6 +703,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm // init Pivot A2A related fields comm->topo->pivotA2AEnabled = false; comm->topo->pivotA2ANumBiRings = 0; + // LL128 + comm->topo->ll128Enabled = false; // Compute paths between GPUs and NICs NCCLCHECK(ncclTopoComputePaths(comm->topo, comm->peerInfo)); // Remove inaccessible GPUs and unused NICs @@ -845,6 +847,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm struct ncclGraphInfo collNet; struct ncclTopoRanks topoRanks; bool pivotA2AEnabled; + bool ll128Enabled; } *allGather3Data; NCCLCHECK(ncclCalloc(&allGather3Data, nranks)); @@ -893,6 +896,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm allGather3Data[rank].collNet.typeInter = collNetGraph.typeInter; allGather3Data[rank].collNetSupport = comm->collNetSupport; allGather3Data[rank].pivotA2AEnabled = comm->topo->pivotA2AEnabled && rcclParamPivotAlltoallEnable(); + comm->topo->ll128Enabled = comm->topo->ll128Enabled || rcclParamLL128ForceEnable(); + allGather3Data[rank].ll128Enabled = comm->topo->ll128Enabled; comm->nChannels = (comm->topo->nodes[GPU].count != comm->topo->nRanks && comm->topo->nodes[NET].count) ? std::min(treeGraph.nChannels, ringGraph.nChannels) : ringGraph.nChannels; @@ -979,6 +984,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm collNetGraph.typeInter = std::max(allGather3Data[i].collNet.typeInter, collNetGraph.typeInter); comm->collNetSupport = std::min(allGather3Data[i].collNetSupport, comm->collNetSupport); comm->topo->pivotA2AEnabled = comm->topo->pivotA2AEnabled && allGather3Data[i].pivotA2AEnabled; + comm->topo->ll128Enabled = comm->topo->ll128Enabled && allGather3Data[i].ll128Enabled; } comm->nChannels = treeGraph.nChannels = ringGraph.nChannels = @@ -1239,7 +1245,7 @@ ncclResult_t ncclCommInitRankSync(ncclComm_t* newcomm, int nranks, ncclUniqueId NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup); NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup); - INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx localSize %ld used %ld bytes - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId, maxLocalSizeBytes, allocTracker[(*newcomm)->cudaDev].totalAllocSize); + INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx localSize %ld used %ld bytes - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId, ncclKernLocalSize(ncclGetKernelIndex(*newcomm)), allocTracker[(*newcomm)->cudaDev].totalAllocSize); return ncclSuccess; cleanup: