Support different protocols and algorithms in the all-reduce-only build (#455)
* Support different protocols and algorithms in the all-reduce-only build
* Restore a line that was deleted in error
[ROCm/rccl commit: 29170a8b5f]
This commit is contained in:
@@ -119,14 +119,22 @@ include_directories(src/include)
|
||||
include_directories(src/collectives)
|
||||
include_directories(src/collectives/device)
|
||||
|
||||
set(CU_SOURCES
|
||||
src/collectives/device/all_reduce.cu
|
||||
src/collectives/device/all_gather.cu
|
||||
src/collectives/device/reduce.cu
|
||||
src/collectives/device/broadcast.cu
|
||||
src/collectives/device/reduce_scatter.cu
|
||||
src/collectives/device/sendrecv.cu
|
||||
src/collectives/device/functions.cu)
|
||||
if (BUILD_ALLREDUCE_ONLY)
|
||||
add_definitions(-DBUILD_ALLREDUCE_ONLY)
|
||||
set(CU_SOURCES
|
||||
src/collectives/device/all_reduce.cu
|
||||
src/collectives/device/sendrecv.cu
|
||||
src/collectives/device/functions.cu)
|
||||
else()
|
||||
set(CU_SOURCES
|
||||
src/collectives/device/all_reduce.cu
|
||||
src/collectives/device/all_gather.cu
|
||||
src/collectives/device/reduce.cu
|
||||
src/collectives/device/broadcast.cu
|
||||
src/collectives/device/reduce_scatter.cu
|
||||
src/collectives/device/sendrecv.cu
|
||||
src/collectives/device/functions.cu)
|
||||
endif()
|
||||
|
||||
set(CPP_SOURCES)
|
||||
foreach(filename ${CU_SOURCES})
|
||||
@@ -209,10 +217,6 @@ if(COLLTRACE)
|
||||
add_definitions(-DENABLE_COLLTRACE)
|
||||
endif()
|
||||
|
||||
if (BUILD_ALLREDUCE_ONLY)
|
||||
add_definitions(-DBUILD_ALLREDUCE_ONLY)
|
||||
endif()
|
||||
|
||||
CHECK_INCLUDE_FILE_CXX("${ROCM_PATH}/rocm_smi/include/rocm_smi/rocm_smi64Config.h" HAVE_ROCM_SMI64CONFIG)
|
||||
IF(HAVE_ROCM_SMI64CONFIG)
|
||||
add_definitions(-DUSE_ROCM_SMI64CONFIG)
|
||||
|
||||
@@ -113,7 +113,7 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
|
||||
// confuses clang. This will be fixed in the next clang release.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
#if defined(BUILD_ALLREDUCE_ONLY)
|
||||
NCCL_FUNC_NAME(AllReduce, RING, SIMPLE, Sum, float)
|
||||
NCCL_FUNC4B(AllReduce, Sum, float),
|
||||
#else
|
||||
NCCL_FUNCS2B(Broadcast),
|
||||
NCCL_FUNCS2A(Reduce),
|
||||
@@ -148,8 +148,22 @@ inline
|
||||
__device__
|
||||
void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
|
||||
#if defined(BUILD_ALLREDUCE_ONLY)
|
||||
assert(c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE));
|
||||
ncclFunction_AllReduce_RING_SIMPLE_Sum_float(c);
|
||||
if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE))
|
||||
ncclFunction_AllReduce_RING_SIMPLE_Sum_float(c);
|
||||
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL))
|
||||
ncclFunction_AllReduce_RING_LL_Sum_float(c);
|
||||
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
|
||||
ncclFunction_AllReduce_RING_LL128_Sum_float(c);
|
||||
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE))
|
||||
ncclFunction_AllReduce_TREE_SIMPLE_Sum_float(c);
|
||||
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL))
|
||||
ncclFunction_AllReduce_TREE_LL_Sum_float(c);
|
||||
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE))
|
||||
ncclFunction_AllReduce_COLLNET_SIMPLE_Sum_float(c);
|
||||
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL))
|
||||
ncclFunction_AllReduce_COLLNET_LL_Sum_float(c);
|
||||
else
|
||||
assert("Unsupported function index");
|
||||
#else
|
||||
if (c->funcIndex < 450) {
|
||||
if (c->funcIndex % 9 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c);
|
||||
|
||||
Reference in New Issue
Block a user