Support different protocols and algorithms in all reduce only build (#455)

* Support different protocols and algorithms in all reduce only build

* Restore deleted line in error

[ROCm/rccl commit: 29170a8b5f]
This commit is contained in:
Wenkai Du
2021-11-02 08:39:08 -07:00
committed by GitHub
orang tua a11b55a37f
melakukan df59f64e3f
2 mengubah file dengan 33 tambahan dan 15 penghapusan
+16 -12
Melihat File
@@ -119,14 +119,22 @@ include_directories(src/include)
include_directories(src/collectives)
include_directories(src/collectives/device)
set(CU_SOURCES
src/collectives/device/all_reduce.cu
src/collectives/device/all_gather.cu
src/collectives/device/reduce.cu
src/collectives/device/broadcast.cu
src/collectives/device/reduce_scatter.cu
src/collectives/device/sendrecv.cu
src/collectives/device/functions.cu)
if (BUILD_ALLREDUCE_ONLY)
add_definitions(-DBUILD_ALLREDUCE_ONLY)
set(CU_SOURCES
src/collectives/device/all_reduce.cu
src/collectives/device/sendrecv.cu
src/collectives/device/functions.cu)
else()
set(CU_SOURCES
src/collectives/device/all_reduce.cu
src/collectives/device/all_gather.cu
src/collectives/device/reduce.cu
src/collectives/device/broadcast.cu
src/collectives/device/reduce_scatter.cu
src/collectives/device/sendrecv.cu
src/collectives/device/functions.cu)
endif()
set(CPP_SOURCES)
foreach(filename ${CU_SOURCES})
@@ -209,10 +217,6 @@ if(COLLTRACE)
add_definitions(-DENABLE_COLLTRACE)
endif()
if (BUILD_ALLREDUCE_ONLY)
add_definitions(-DBUILD_ALLREDUCE_ONLY)
endif()
CHECK_INCLUDE_FILE_CXX("${ROCM_PATH}/rocm_smi/include/rocm_smi/rocm_smi64Config.h" HAVE_ROCM_SMI64CONFIG)
IF(HAVE_ROCM_SMI64CONFIG)
add_definitions(-DUSE_ROCM_SMI64CONFIG)
@@ -113,7 +113,7 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
// confuses clang. This will be fixed in the next clang release.
#if defined(__HIP_DEVICE_COMPILE__)
#if defined(BUILD_ALLREDUCE_ONLY)
NCCL_FUNC_NAME(AllReduce, RING, SIMPLE, Sum, float)
NCCL_FUNC4B(AllReduce, Sum, float),
#else
NCCL_FUNCS2B(Broadcast),
NCCL_FUNCS2A(Reduce),
@@ -148,8 +148,22 @@ inline
__device__
void NCCL_CALL_FUNCTIONS(struct ncclWorkElem* const c) noexcept {
#if defined(BUILD_ALLREDUCE_ONLY)
assert(c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE));
ncclFunction_AllReduce_RING_SIMPLE_Sum_float(c);
if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_RING_SIMPLE_Sum_float(c);
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL))
ncclFunction_AllReduce_RING_LL_Sum_float(c);
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
ncclFunction_AllReduce_RING_LL128_Sum_float(c);
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_TREE_SIMPLE_Sum_float(c);
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL))
ncclFunction_AllReduce_TREE_LL_Sum_float(c);
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_COLLNET_SIMPLE_Sum_float(c);
else if (c->funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL))
ncclFunction_AllReduce_COLLNET_LL_Sum_float(c);
else
assert("Unsupported function index");
#else
if (c->funcIndex < 450) {
if (c->funcIndex % 9 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t(c);