Remove workaround and use indirect function call (#684)

[ROCm/rccl commit: f7a456122c]
Этот коммит содержится в:
Wenkai Du
2023-02-14 13:59:48 -08:00
коммит произвёл GitHub
родитель cb7e2e8eeb
Коммит 4fb1ebcf4b
12 изменённых файлов: 100 добавлений и 8 удалений
+10 -2
Просмотреть файл
@@ -463,13 +463,21 @@ foreach(target ${AMDGPU_TARGETS})
endforeach()
if("${HIP_COMPILER}" MATCHES "clang")
target_compile_options(rccl PRIVATE -fvisibility=hidden --hipcc-func-supp)
find_program( hipcc_executable hipcc )
execute_process(COMMAND bash "-c" "${hipcc_executable} --version | grep 'HIP version' | awk -F\" \" '{ printf $3}' | awk -F\"-\" '{ printf $1}'" OUTPUT_VARIABLE hipcc_version_string)
message(STATUS "hipcc version: ${hipcc_version_string}")
if(${hipcc_version_string} VERSION_GREATER_EQUAL "5.5.30201")
add_definitions(-DUSE_INDIRECT_FUNCTION_CALL)
target_compile_options(rccl PRIVATE -fvisibility=hidden)
message(STATUS "Indirect function call enabled")
else()
target_compile_options(rccl PRIVATE -fvisibility=hidden --hipcc-func-supp)
endif()
foreach(target ${AMDGPU_TARGETS})
target_compile_options(rccl PRIVATE -fgpu-rdc)
endforeach()
target_link_libraries(rccl PRIVATE -fgpu-rdc)
target_include_directories(rccl PRIVATE ${ROCM_PATH}/include)
find_program( hipcc_executable hipcc )
execute_process(COMMAND bash "-c" "${hipcc_executable} -help | grep 'parallel-jobs'" OUTPUT_VARIABLE hipcc_parallel_jobs)
if("${hipcc_parallel_jobs}" MATCHES "parallel-jobs")
target_compile_options(rccl PRIVATE -parallel-jobs=12 PRIVATE -Wno-format-nonliteral)
+4
Просмотреть файл
@@ -11,7 +11,11 @@
namespace {
template<typename T, typename RedOp, typename Proto>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void runRing(ncclWorkElem *args) {
#else
__device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
#endif
const int tid = threadIdx.x;
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
+12
Просмотреть файл
@@ -15,7 +15,11 @@
namespace {
template<typename T, typename RedOp, typename Proto>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void runRing(ncclWorkElem *args) {
#else
__device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
#endif
const int tid = threadIdx.x;
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
@@ -219,7 +223,11 @@ namespace {
}
template<typename T, typename RedOp, typename Proto>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void runTreeUpDown(ncclWorkElem *args) {
#else
__device__ __attribute__((noinline)) void runTreeUpDown(ncclWorkElem *args) {
#endif
const int tid = threadIdx.x;
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
@@ -371,7 +379,11 @@ namespace {
}
template<typename T, typename RedOp, typename Proto>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void runTreeSplit(ncclWorkElem *args) {
#else
__device__ __attribute__((noinline)) void runTreeSplit(ncclWorkElem *args) {
#endif
const int tid = threadIdx.x;
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
+4
Просмотреть файл
@@ -10,7 +10,11 @@
namespace {
template<typename T, typename RedOp, typename Proto>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void runRing(ncclWorkElem *args) {
#else
__device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
#endif
const int tid = threadIdx.x;
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
+4
Просмотреть файл
@@ -10,7 +10,11 @@
namespace {
template<typename T, typename RedOp, typename Proto>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void runRing(ncclWorkElem *args) {
#else
__device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
#endif
const int tid = threadIdx.x;
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
+22 -3
Просмотреть файл
@@ -199,6 +199,10 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs_ll128[]{
#endif
};
static_assert(FUNC_INDEX_P2P == 3610, "Wrong P2P function index");
static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 3611, "Wrong AllToAllPivot function index");
#ifndef USE_INDIRECT_FUNCTION_CALL
template<unsigned short f, unsigned short l, bool u>
struct Caller {
static __forceinline__ __device__ __host__
@@ -216,9 +220,6 @@ struct Caller<f, f + 1, u>{
void call(unsigned short funcIndex) noexcept { if (u) ncclFuncs_ll128[f](); else ncclFuncs[f](); }
};
static_assert(FUNC_INDEX_P2P == 3610, "Wrong P2P function index");
static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 3611, "Wrong AllToAllPivot function index");
template<bool USING_LL128>
__forceinline__
__device__
@@ -340,11 +341,16 @@ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {
}
#endif
}
#endif
template <ncclFunc_t FUNCTION, int ALGO, int PROTO, class REDOP, typename T, int UNROLL>
class ncclFunction {
public:
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ __attribute__((noinline)) void run(struct ncclWorkElem* args) {}
#else
__device__ void run(struct ncclWorkElem* args) {}
#endif
};
#ifdef ENABLE_COLLTRACE
@@ -663,7 +669,12 @@ __forceinline__ __device__ void ncclKernel(
if (ncclShmem.work.header.funcIndex == FnIndex) {
RunWork<Fn, T, RedOp, Algo, Proto>().run(&ncclShmem.work);
} else {
#ifdef USE_INDIRECT_FUNCTION_CALL
if (USING_LL128) ncclFuncs_ll128[ncclShmem.work.header.funcIndex]();
else ncclFuncs[ncclShmem.work.header.funcIndex]();
#else
NCCL_CALL_FUNCTIONS<USING_LL128>(ncclShmem.work.header.funcIndex);
#endif
}
int workIxNext = ncclShmem.work.header.workNext;
@@ -714,10 +725,18 @@ __global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(st
// Examples : AllReduce, RING, LL, Sum, uint8
/* Functions for aggregation case */
#ifdef USE_INDIRECT_FUNCTION_CALL
#define IMPL_COLL_FUNC(func, algo, proto, devredop, type) \
__device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)() { \
RunWork<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem.work); \
}
#else
#define IMPL_COLL_FUNC(func, algo, proto, devredop, type) \
__device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)() { \
RunWork<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto>().run(&ncclShmem.work); \
}
#endif
// Only generate inline kernels for LL
#define IMPL_COLL4(func, algo, devredop, type, ncclType) \
+11
Просмотреть файл
@@ -12,7 +12,11 @@
namespace {
template<typename T, typename RedOp>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void oneRankReduce() {
#else
__device__ __attribute__((noinline)) void oneRankReduce() {
#endif
ncclWork *w = &ncclShmem.work;
int tid = threadIdx.x;
int tn = blockDim.x;
@@ -42,10 +46,17 @@ namespace {
}
}
#ifdef USE_INDIRECT_FUNCTION_CALL
#define INSTANTIATE(devredop, type) \
__device__ void NCCL_ONERANK_REDUCE_NAME(devredop, type)() { \
oneRankReduce<type, Func##devredop<type>>(); \
}
#else
#define INSTANTIATE(devredop, type) \
__device__ __attribute__((noinline)) void NCCL_ONERANK_REDUCE_NAME(devredop, type)() { \
oneRankReduce<type, Func##devredop<type>>(); \
}
#endif
INSTANTIATE(PreMulSum, int8_t)
INSTANTIATE(PreMulSum, uint8_t)
+4
Просмотреть файл
@@ -11,7 +11,11 @@
namespace {
template<typename T, typename RedOp, typename Proto>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void runRing(ncclWorkElem *args) {
#else
__device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
#endif
const int tid = threadIdx.x;
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
+4
Просмотреть файл
@@ -11,7 +11,11 @@
namespace {
template<typename T, typename RedOp, typename Proto>
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void runRing(ncclWorkElem *args) {
#else
__device__ __attribute__((noinline)) void runRing(ncclWorkElem *args) {
#endif
const int tid = threadIdx.x;
const int nthreads = args->nWarps*WARP_SIZE;
const int bid = args->bid;
+4
Просмотреть файл
@@ -175,7 +175,11 @@ struct RunWork<ncclFuncSendRecv, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
}
}
#ifdef USE_INDIRECT_FUNCTION_CALL
__device__ void run(ncclWork *work) {
#else
__device__ __attribute__((noinline)) void run(ncclWork *work) {
#endif
struct ncclWorkElemP2p* args = work->p2pElems;
int ngroups = args->ngroups;
int tid = threadIdx.x;
+9
Просмотреть файл
@@ -45,12 +45,21 @@ struct ncclDevRedOpFull {
nccl##func##algo##proto
/* Declare all collective operations */
#ifdef USE_INDIRECT_FUNCTION_CALL
#define DECL5(func, algo, proto, devredop, type) \
extern __device__ void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
extern __global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
extern __global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
#else
#define DECL5(func, algo, proto, devredop, type) \
extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
extern __global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead); \
extern __global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead);
#endif
#define SINGLE_ARG(...) __VA_ARGS__
#define CONCAT(a,b) a##b
+12 -3
Просмотреть файл
@@ -1414,7 +1414,13 @@ fail:
goto exit;
}
#ifdef USE_INDIRECT_FUNCTION_CALL
NCCL_PARAM(SetStackSize, "SET_STACK_SIZE", 1);
RCCL_PARAM(StackSizeOverride, "STACK_SIZE_OVERRIDE", 8);
#else
NCCL_PARAM(SetStackSize, "SET_STACK_SIZE", 0);
RCCL_PARAM(StackSizeOverride, "STACK_SIZE_OVERRIDE", 0);
#endif
struct ncclCommInitRankAsyncJob {
struct ncclAsyncJob base;
@@ -1440,14 +1446,17 @@ static ncclResult_t ncclCommInitRankFunc(struct ncclAsyncJob* job_) {
int cudaDev = job->cudaDev;
int virtualId = job->virtualId;
ncclResult_t res = ncclSuccess;
int64_t stackSize = rcclParamStackSizeOverride() ? rcclParamStackSizeOverride() : maxLocalSizeBytes;
CUDACHECKGOTO(cudaSetDevice(cudaDev), res, fail);
// Set the maximum kernel stack size of all kernels to avoid
// a CUDA memory reconfig on load (c.f. NVSHMEM issue)
if (maxLocalSizeBytes > 0 && ncclParamSetStackSize() == 1) {
TRACE(NCCL_INIT, "Setting cudaLimitStackSize to %zi", maxLocalSizeBytes);
//CUDACHECKIGNORE(cudaDeviceSetLimit(cudaLimitStackSize, maxLocalSizeBytes));
#ifdef USE_INDIRECT_FUNCTION_CALL
if (stackSize > 0 && ncclParamSetStackSize() == 1) {
INFO(NCCL_INIT, "Setting cudaLimitStackSize to %zi maxLocalSizeBytes %zi", stackSize, maxLocalSizeBytes);
CUDACHECKIGNORE(cudaDeviceSetLimit(cudaLimitStackSize, stackSize));
}
#endif
NCCLCHECKGOTO(commAlloc(newcomm, nranks, myrank, virtualId), res, fail);
NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, fail);