diff --git a/projects/rccl/CMakeLists.txt b/projects/rccl/CMakeLists.txt index 98bbb85645..e5e62e56b0 100644 --- a/projects/rccl/CMakeLists.txt +++ b/projects/rccl/CMakeLists.txt @@ -119,14 +119,22 @@ include_directories(src/include) include_directories(src/collectives) include_directories(src/collectives/device) -set(CU_SOURCES - src/collectives/device/all_reduce.cu - src/collectives/device/all_gather.cu - src/collectives/device/reduce.cu - src/collectives/device/broadcast.cu - src/collectives/device/reduce_scatter.cu - src/collectives/device/sendrecv.cu - src/collectives/device/functions.cu) +if (BUILD_ALLREDUCE_ONLY) + add_definitions(-DBUILD_ALLREDUCE_ONLY) + set(CU_SOURCES + src/collectives/device/all_reduce.cu + src/collectives/device/sendrecv.cu + src/collectives/device/functions.cu) +else() + set(CU_SOURCES + src/collectives/device/all_reduce.cu + src/collectives/device/all_gather.cu + src/collectives/device/reduce.cu + src/collectives/device/broadcast.cu + src/collectives/device/reduce_scatter.cu + src/collectives/device/sendrecv.cu + src/collectives/device/functions.cu) +endif() set(CPP_SOURCES) foreach(filename ${CU_SOURCES}) @@ -209,10 +217,6 @@ if(COLLTRACE) add_definitions(-DENABLE_COLLTRACE) endif() -if (BUILD_ALLREDUCE_ONLY) - add_definitions(-DBUILD_ALLREDUCE_ONLY) -endif() - CHECK_INCLUDE_FILE_CXX("${ROCM_PATH}/rocm_smi/include/rocm_smi/rocm_smi64Config.h" HAVE_ROCM_SMI64CONFIG) IF(HAVE_ROCM_SMI64CONFIG) add_definitions(-DUSE_ROCM_SMI64CONFIG) diff --git a/projects/rccl/src/collectives/device/common.h b/projects/rccl/src/collectives/device/common.h index f0edb0aa3a..c11f35e882 100644 --- a/projects/rccl/src/collectives/device/common.h +++ b/projects/rccl/src/collectives/device/common.h @@ -113,7 +113,7 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{ // confuses clang. This will be fixed in the next clang release. #if defined(__HIP_DEVICE_COMPILE__) #if defined(BUILD_ALLREDUCE_ONLY) - NCCL_FUNC_NAME(AllReduce, RING, SIMPLE, Sum, float) + NCCL_FUNC4B(AllReduce, Sum, float), #else NCCL_FUNCS2B(Broadcast), NCCL_FUNCS2A(Reduce), @@ -148,8 +148,22 @@ inline __device__ void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept { #if defined(BUILD_ALLREDUCE_ONLY) - assert(c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE)); - ncclFunction_AllReduce_RING_SIMPLE_Sum_float(c); + if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE)) + ncclFunction_AllReduce_RING_SIMPLE_Sum_float(c); + else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL)) + ncclFunction_AllReduce_RING_LL_Sum_float(c); + else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128)) + ncclFunction_AllReduce_RING_LL128_Sum_float(c); + else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE)) + ncclFunction_AllReduce_TREE_SIMPLE_Sum_float(c); + else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL)) + ncclFunction_AllReduce_TREE_LL_Sum_float(c); + else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE)) + ncclFunction_AllReduce_COLLNET_SIMPLE_Sum_float(c); + else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL)) + ncclFunction_AllReduce_COLLNET_LL_Sum_float(c); + else + assert("Unsupported function index"); #else if (c->funcIndex < 450) { if (c->funcIndex % 9 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c);