From 37bf54b8f88296a56d12a666dee7f6ec936a62fd Mon Sep 17 00:00:00 2001 From: corey-derochie-amd <161367113+corey-derochie-amd@users.noreply.github.com> Date: Thu, 4 Jul 2024 09:34:38 -0600 Subject: [PATCH] Enable multi-threading for MSCCL (#1203) MSCCL can now run in a multi-threaded configuration. To test in the unit tests, added the ENABLE_OPENMP compile definition flag and the --openmp-test-enable flag to the unit test build script. To activate, set the environment variables UT_MULTITHREAD=1 and UT_PROCESS_MASK=1. Set Jenkins to use this mode. [ROCm/rccl commit: 0c36d571eadd9361aed924aef6ad3dfe010c5db0] --- projects/rccl/.jenkins/common.groovy | 2 +- projects/rccl/install.sh | 10 +- projects/rccl/src/collectives.cc | 22 ++-- .../rccl/src/include/msccl/msccl_lifecycle.h | 4 +- projects/rccl/src/include/msccl/msccl_setup.h | 4 +- .../rccl/src/include/msccl/msccl_status.h | 10 +- projects/rccl/src/init.cc | 4 +- .../rccl/src/misc/msccl/msccl_lifecycle.cc | 96 ++++++++--------- projects/rccl/src/misc/msccl/msccl_setup.cc | 27 +++-- projects/rccl/src/misc/msccl/msccl_status.cc | 62 +++++++++-- projects/rccl/src/msccl.cc | 23 +--- projects/rccl/src/nccl.h.in | 1 + projects/rccl/test/CMakeLists.txt | 17 +++ projects/rccl/test/common/EnvVars.cpp | 4 +- projects/rccl/test/common/EnvVars.hpp | 1 + projects/rccl/test/common/ErrCode.hpp | 34 ++++-- projects/rccl/test/common/TestBed.cpp | 2 +- projects/rccl/test/common/TestBedChild.cpp | 101 +++++++++++++----- projects/rccl/test/common/TestBedChild.hpp | 3 +- 19 files changed, 279 insertions(+), 148 deletions(-) diff --git a/projects/rccl/.jenkins/common.groovy b/projects/rccl/.jenkins/common.groovy index ea9f4778ae..22d39257bf 100644 --- a/projects/rccl/.jenkins/common.groovy +++ b/projects/rccl/.jenkins/common.groovy @@ -24,7 +24,7 @@ def runTestCommand (platform, project, gfilter, envars) cd ${project.paths.project_build_prefix}/build/release/test ${sudo} ulimit -l unlimited ulimit -a - ${sudo} ${envars} 
RCCL_ENABLE_SIGNALHANDLER=1 NCCL_DEBUG=INFO HSA_FORCE_FINE_GRAIN_PCIE=1 ./rccl-UnitTests --gtest_filter=${gfilter} --gtest_output=xml --gtest_color=yes + ${sudo} ${envars} RCCL_ENABLE_SIGNALHANDLER=1 NCCL_DEBUG=INFO HSA_FORCE_FINE_GRAIN_PCIE=1 UT_MULTITHREAD=1 UT_PROCESS_MASK=1 ./rccl-UnitTests --gtest_filter=${gfilter} --gtest_output=xml --gtest_color=yes """ platform.runCommand(this, command) diff --git a/projects/rccl/install.sh b/projects/rccl/install.sh index f8cd414007..9b9b2ce62c 100755 --- a/projects/rccl/install.sh +++ b/projects/rccl/install.sh @@ -26,6 +26,7 @@ install_prefix="${ROCM_PATH}" msccl_kernel_enabled=true num_parallel_jobs=$(nproc) npkit_enabled=false +openmp_test_enabled=false roctx_enabled=false run_tests=false run_tests_all=false @@ -52,6 +53,7 @@ function display_help() echo " --amdgpu_targets Only compile for specified GPU architecture(s). For multiple targets, seperate by ';' (builds for all supported GPU architectures by default)" echo " --no_clean Don't delete files if they already exist" echo " --npkit-enable Compile with npkit enabled" + echo " --openmp-test-enable Enable OpenMP in rccl unit tests" echo " --roctx-enable Compile with roctx enabled (example usage: rocprof --roctx-trace ./rccl-program)" echo " -p|--package_build Build RCCL package" echo " --prefix Specify custom directory to install RCCL to (default: \`/opt/rocm\`)" @@ -71,7 +73,7 @@ function display_help() # check if we have a modern version of getopt that can handle whitespace and long parameters getopt -T if [[ "$?" 
-eq 4 ]]; then - GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@") + GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@") else echo "Need a new version of getopt" exit 1 @@ -100,6 +102,7 @@ while true; do --amdgpu_targets) build_amdgpu_targets=${2}; shift 2 ;; --no_clean) clean_build=false; shift ;; --npkit-enable) npkit_enabled=true; shift ;; + --openmp-test-enable) openmp_test_enabled=true; shift ;; --roctx-enable) roctx_enabled=true; shift ;; -p | --package_build) build_package=true; shift ;; --prefix) install_library=true; install_prefix=${2}; shift 2 ;; @@ -246,6 +249,11 @@ if [[ "${roctx_enabled}" == true ]]; then cmake_common_options="${cmake_common_options} -DROCTX=ON" fi +# Enable OpenMP in unit tests +if [[ "${openmp_test_enabled}" == true ]]; then + cmake_common_options="${cmake_common_options} -DOPENMP_TESTS_ENABLED=ON" +fi + # Enable NPKit npkit_options="" if [[ "${npkit_enabled}" == true ]]; then diff --git a/projects/rccl/src/collectives.cc b/projects/rccl/src/collectives.cc index 3fa2f83d9d..896feaac51 100644 --- a/projects/rccl/src/collectives.cc +++ b/projects/rccl/src/collectives.cc @@ -23,7 +23,7 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun size_t msgsize = sendcount * ncclTypeSize(datatype); NVTX3_FUNC_WITH_PARAMS(AllGather, AllGatherSchema, msgsize) - 
if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream); @@ -52,7 +52,7 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, NvtxParamsAllReduce payload{count * ncclTypeSize(datatype), op}; NVTX3_FUNC_WITH_PARAMS(AllReduce, AllReduceSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream); @@ -75,7 +75,7 @@ ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, nc size_t msgsize = count * ncclTypeSize(datatype); NVTX3_FUNC_WITH_PARAMS(AllToAll, AllToAllSchema, msgsize) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream); @@ -122,7 +122,7 @@ ncclResult_t ncclAllToAllv(const void *sendbuff, const size_t sendcounts[], cons NvtxParamsAllToAllv payload{sendcounts[comm->rank] * ncclTypeSize(datatype), recvcounts[comm->rank] * ncclTypeSize(datatype)}; NVTX3_FUNC_WITH_PARAMS(AllToAllv, AllToAllvSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls, 0, datatype, 0, 0, ncclSum, mscclFuncAllToAllv, comm, stream); @@ -166,7 +166,7 @@ ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, n NvtxParamsBroadcast payload{count * ncclTypeSize(datatype), root}; NVTX3_FUNC_WITH_PARAMS(Broadcast, BroadcastSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if 
(mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream); @@ -200,7 +200,7 @@ ncclResult_t ncclGather(const void* sendbuff, void* recvbuff, size_t sendcount, NvtxParamsGather payload{sendcount * ncclTypeSize(datatype), root}; NVTX3_FUNC_WITH_PARAMS(Gather, GatherSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, sendcount, datatype, root, 0, ncclSum, mscclFuncGather, comm, stream); @@ -240,7 +240,7 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, NvtxParamsReduce payload{count * ncclTypeSize(datatype), root, op}; NVTX3_FUNC_WITH_PARAMS(Reduce, ReduceSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, root, 0, op, mscclFuncReduce, comm, stream); @@ -268,7 +268,7 @@ ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recv NvtxParamsReduceScatter payload{recvcount * ncclTypeSize(datatype), op}; NVTX3_FUNC_WITH_PARAMS(ReduceScatter, ReduceScatterSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream); @@ -295,7 +295,7 @@ ncclResult_t ncclScatter(const void* sendbuff, void* recvbuff, size_t recvcount, NvtxParamsScatter payload{recvcount * ncclTypeSize(datatype), root}; NVTX3_FUNC_WITH_PARAMS(Scatter, ScatterSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( 
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr, recvcount, datatype, root, 0, ncclSum, mscclFuncScatter, comm, stream); @@ -333,7 +333,7 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer}; NVTX3_FUNC_WITH_PARAMS(Send, SendRecvSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr, count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream); @@ -356,7 +356,7 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer}; NVTX3_FUNC_WITH_PARAMS(Recv, SendRecvSchema, payload) - if (mscclAvailable() && !mscclIsCaller()) { + if (mscclAvailable(comm->rank) && !mscclIsCaller()) { return mscclEnqueueCheck( nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr, count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream); diff --git a/projects/rccl/src/include/msccl/msccl_lifecycle.h b/projects/rccl/src/include/msccl/msccl_lifecycle.h index eac4f3a7ac..459a046613 100644 --- a/projects/rccl/src/include/msccl/msccl_lifecycle.h +++ b/projects/rccl/src/include/msccl/msccl_lifecycle.h @@ -17,7 +17,7 @@ void mscclSetIsCallerFlag(); void mscclClearIsCallerFlag(); bool mscclIsCaller(); -bool mscclAvailable(); +bool mscclAvailable(int rank = -1); ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired); @@ -33,7 +33,7 @@ ncclResult_t mscclEnqueueCheck( ncclResult_t mscclGroupEnd(); -ncclResult_t mscclTeardown(); +ncclResult_t mscclTeardown(int rank); size_t mscclKernMaxLocalSize(); diff --git a/projects/rccl/src/include/msccl/msccl_setup.h b/projects/rccl/src/include/msccl/msccl_setup.h index 7a326236ce..b3f88e35ec 100644 --- a/projects/rccl/src/include/msccl/msccl_setup.h +++ b/projects/rccl/src/include/msccl/msccl_setup.h @@ 
-11,11 +11,11 @@ #include "comm.h" #include "msccl/msccl_struct.h" -ncclResult_t mscclGetCaptureStatus(hipStream_t stream); +ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream); ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream); -ncclResult_t mscclSetupSyncFlags(hipStream_t stream); +ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream); ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm); diff --git a/projects/rccl/src/include/msccl/msccl_status.h b/projects/rccl/src/include/msccl/msccl_status.h index a8bfb09ed3..077709ddc8 100644 --- a/projects/rccl/src/include/msccl/msccl_status.h +++ b/projects/rccl/src/include/msccl/msccl_status.h @@ -8,9 +8,15 @@ #include "msccl/msccl_struct.h" -mscclStatus& mscclGetStatus(); +bool mscclInitialized(int rank); -mscclSavedProxyArgs& mscclGetSavedProxyArgs(); +void mscclSetInitialized(int rank, bool initialized = true); + +void mscclRemoveRank(int rank); + +mscclStatus& mscclGetStatus(int rank); + +mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank); mscclThreadLocalStatus& mscclGetThreadLocalStatus(); diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc index 63725bda9c..92feb355bd 100644 --- a/projects/rccl/src/init.cc +++ b/projects/rccl/src/init.cc @@ -1738,7 +1738,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) { NCCLCHECK(mscclInit(comm)); - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); status.needsProxy |= mscclNeedsProxy; } @@ -2351,7 +2351,7 @@ static ncclResult_t commCleanup(ncclComm_t comm) { #endif if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) { - NCCLCHECK(mscclTeardown()); + NCCLCHECK(mscclTeardown(comm->rank)); } return ncclSuccess; diff --git a/projects/rccl/src/misc/msccl/msccl_lifecycle.cc b/projects/rccl/src/misc/msccl/msccl_lifecycle.cc 
index 7b8f1f4c41..60ec97804b 100644 --- a/projects/rccl/src/misc/msccl/msccl_lifecycle.cc +++ b/projects/rccl/src/misc/msccl/msccl_lifecycle.cc @@ -25,8 +25,6 @@ RCCL_PARAM(MscclEnabled, "MSCCL_ENABLE", 1); RCCL_PARAM(MscclForceEnabled, "MSCCL_FORCE_ENABLE", 0); static const char* mscclAlgoFilePathEnv = "MSCCL_ALGO_FILE_PATH"; -static std::atomic mscclInitialized; -static std::mutex mscclLifecycleMutex; bool mscclEnabled() { #ifdef COMPILE_MSCCL_KERNEL @@ -56,23 +54,12 @@ bool mscclIsCaller() { return mscclGetThreadLocalStatus().mscclIsCallerFlag; } -bool mscclAvailable() { - return mscclEnabled() && mscclInitialized.load(std::memory_order_acquire); +bool mscclAvailable(int rank) { + return mscclEnabled() && mscclInitialized(rank); } static bool mscclCommCompatible(ncclComm_t comm) { - std::map> hostHashToPidHashes; - for (int i = 0; i < comm->nRanks; i++) { - uint64_t hostHash = comm->peerInfo[i].hostHash; - uint64_t pidHash = comm->peerInfo[i].pidHash; - if (hostHashToPidHashes.find(hostHash) != hostHashToPidHashes.end()) { - auto& pidHashSet = hostHashToPidHashes[hostHash]; - if (pidHashSet.find(pidHash) != pidHashSet.end()) { - return false; - } - } - hostHashToPidHashes[hostHash].insert(pidHash); - } + // MSCCL is always compatible now. No need to guard against multi-thread. 
return true; } @@ -86,8 +73,8 @@ static const char* mscclAlgoShareDirPath = "../share/rccl/msccl-algorithms"; static const char* mscclUnitTestAlgoShareDirPath = "../share/rccl/msccl-unit-test-algorithms"; static ncclResult_t mscclInternalSchedulerInit(ncclComm_t comm, int* numChannelsRequired) { - static bool mscclAlgoMetaLoaded = false; - mscclStatus& status = mscclGetStatus(); + static thread_local bool mscclAlgoMetaLoaded = false; + mscclStatus& status = mscclGetStatus(comm->rank); *numChannelsRequired = 0; // Query numChannelsRequired from loaded algorithm metas @@ -166,9 +153,7 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) { return ncclSuccess; } - std::lock_guard lock(mscclLifecycleMutex); - - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); bool useInternalScheduler = false; const char* mscclSchedulerPath = getenv(mscclSchedulerPathEnv); @@ -199,20 +184,11 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) { } ncclResult_t mscclInit(ncclComm_t comm) { - // Always initialize thread local status - mscclThreadLocalStatus threadLocalStatus = mscclGetThreadLocalStatus(); - threadLocalStatus.groupStatus = mscclNoGroup; - threadLocalStatus.groupDepth = 0; - threadLocalStatus.captureId = ULLONG_MAX; - threadLocalStatus.captureStatus = mscclNoCapture; - { - std::lock_guard lock(mscclLifecycleMutex); - - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); // freeAlgoHandles and needsProxy are initialized globally once and before algorithm pre-processing and connection - if (!mscclInitialized.load(std::memory_order_acquire)) { + if (!mscclInitialized(comm->rank)) { status.freeAlgoHandles.resize(MSCCL_MAX_NUM_ALGOS); for (int i = 0; i < MSCCL_MAX_NUM_ALGOS; i++) { status.freeAlgoHandles[i] = MSCCL_MAX_NUM_ALGOS - i - 1; @@ -229,6 +205,8 @@ ncclResult_t mscclInit(ncclComm_t comm) { if (m.nRanks == comm->nRanks) { // Load algorithms 
if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()) { + static std::mutex loadAlgoMutex; + std::lock_guard lock(loadAlgoMutex); NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm->rank)); } // Connect algorithms @@ -241,7 +219,7 @@ ncclResult_t mscclInit(ncclComm_t comm) { } } - if (mscclInitialized.load(std::memory_order_acquire)) { + if (mscclInitialized(comm->rank)) { return ncclSuccess; } @@ -250,7 +228,7 @@ ncclResult_t mscclInit(ncclComm_t comm) { status.lastStream = nullptr; NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus))); - mscclInitialized.store(true, std::memory_order_release); + mscclSetInitialized(comm->rank); } INFO(NCCL_INIT, "MSCCL: Initialization finished, localSize %ld", mscclKernMaxLocalSize()); @@ -266,8 +244,8 @@ ncclResult_t mscclGroupStart() { return ncclSuccess; } -static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam* param) { - mscclStatus& status = mscclGetStatus(); +static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSchedulerParam* param) { + mscclStatus& status = mscclGetStatus(rank); param->scheduled = false; // Current MSCCL doesn't support pre/post op @@ -312,13 +290,13 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam* } static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* param) { - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(param->comm->rank); if (status.mscclSchedulerPtr) { NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p))); } else { // Disable MSCCL algorithms if machine type is not matching if (param->comm->topo->mscclEnabled || mscclForceEnabled()) { - NCCLCHECK(mscclInternalSchedulerSelectAlgo(&(param->p))); + NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm->rank, &(param->p))); } else { param->p.scheduled = false; } @@ -513,12 +491,30 @@ ncclResult_t mscclGroupEnd() { 
return ncclSuccess; } -static ncclResult_t mscclInternalSchedulerTeardown() { +static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlgoHandle) { + mscclStatus& status = mscclGetStatus(rank); + + free(status.hostAlgos[mscclAlgoHandle]); + status.hostAlgos.erase(mscclAlgoHandle); + + NCCLCHECK(ncclCudaFree(status.devAlgos[mscclAlgoHandle])); + status.devAlgos.erase(mscclAlgoHandle); + + status.freeAlgoHandles.push_back(mscclAlgoHandle); + + for (auto &s : status.connectedAlgos) { + s.second.erase(mscclAlgoHandle); + } + + return ncclSuccess; +} + +static ncclResult_t mscclInternalSchedulerTeardown(int rank) { ncclResult_t ret = ncclSuccess, tmpRet = ncclSuccess; - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(rank); for (auto &m : status.rankToAlgoHandles) { for (auto &p : m) { - tmpRet = mscclUnloadAlgo(p.second); + tmpRet = mscclInternalUnloadAlgo(rank, p.second); if (ret == ncclSuccess) { ret = tmpRet; } @@ -529,18 +525,13 @@ static ncclResult_t mscclInternalSchedulerTeardown() { return ret; } -ncclResult_t mscclTeardown() { - // Always teardown thread local status - mscclThreadLocalStatus threadLocalStatus = mscclGetThreadLocalStatus(); - threadLocalStatus.savedSchedulerParams.clear(); - +ncclResult_t mscclTeardown(int rank) { { - std::lock_guard lock(mscclLifecycleMutex); - - if (!mscclInitialized.load(std::memory_order_acquire)) { + if (!mscclInitialized(rank)) { + mscclRemoveRank(rank); return ncclSuccess; } - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(rank); for (auto &p : status.hostAlgos) { free(p.second); status.freeAlgoHandles.push_back(p.first); @@ -563,13 +554,14 @@ ncclResult_t mscclTeardown() { dlclose(status.mscclSchedulerLib); status.mscclSchedulerLib = nullptr; } else { - NCCLCHECK(mscclInternalSchedulerTeardown()); + NCCLCHECK(mscclInternalSchedulerTeardown(rank)); } NCCLCHECK(mscclDestroyWorkFifoStatus(&(status.defaultWorkFifoStatus))); for (auto &p : 
status.graphWorkFifoStatus) { NCCLCHECK(mscclDestroyWorkFifoStatus(&(p.second))); } - mscclInitialized.store(false, std::memory_order_release); + mscclSetInitialized(rank, false); + mscclRemoveRank(rank); } INFO(NCCL_INIT, "MSCCL: Teardown finished"); diff --git a/projects/rccl/src/misc/msccl/msccl_setup.cc b/projects/rccl/src/misc/msccl/msccl_setup.cc index 292aa88678..1068887067 100644 --- a/projects/rccl/src/misc/msccl/msccl_setup.cc +++ b/projects/rccl/src/misc/msccl/msccl_setup.cc @@ -22,10 +22,10 @@ static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nC return (nBytes * (size_t)nScratchChunks) / (size_t)nChunksPerLoop; } -ncclResult_t mscclGetCaptureStatus(hipStream_t stream) { - mscclStatus& status = mscclGetStatus(); +ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream) { + mscclStatus& status = mscclGetStatus(rank); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); - mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(); + mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(rank); cudaStreamCaptureStatus captureStatus; unsigned long long captureId; CUDACHECK(hipStreamGetCaptureInfo_v2(stream, &captureStatus, &captureId, &threadLocalStatus.graph, nullptr, nullptr)); @@ -42,12 +42,12 @@ ncclResult_t mscclGetCaptureStatus(hipStream_t stream) { } else { threadLocalStatus.captureStatus = mscclNoCapture; } - INFO(NCCL_NET,"mscclGetCaptureStatus: %d, captureId: %llu, size: %lu\n", threadLocalStatus.captureStatus, threadLocalStatus.captureId, mscclGetSavedProxyArgs()[captureId].size()); + INFO(NCCL_NET,"mscclGetCaptureStatus: %d, captureId: %llu, size: %lu\n", threadLocalStatus.captureStatus, threadLocalStatus.captureId, savedProxyArgs[captureId].size()); return ncclSuccess; } ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) { - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); 
status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS; status.chunkSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->chunkSteps : 1; status.sliceSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->sliceSteps : 1; @@ -69,12 +69,11 @@ ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t } ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) { - mscclStatus& status = mscclGetStatus(); return ncclSuccess; } -ncclResult_t mscclSetupSyncFlags(hipStream_t stream) { - mscclStatus& status = mscclGetStatus(); +ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) { + mscclStatus& status = mscclGetStatus(rank); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); if (threadLocalStatus.captureStatus == mscclNewCapture || status.workIndex > (1ULL << (8*sizeof(status.workIndex))) - 2 * NCCL_MAX_OPS - 1) { @@ -86,7 +85,7 @@ ncclResult_t mscclSetupSyncFlags(hipStream_t stream) { } ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) { - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); // Check whether there are enough channels if (hostAlgo->nChannels > comm->nChannels) { @@ -124,7 +123,7 @@ ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) } static ncclResult_t mscclSetupProxyImpl(struct mscclAlgo* hostAlgo, ncclComm_t comm) { - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); struct ncclProxyOp proxyOp = {}; proxyOp.connIndex = 0; @@ -185,9 +184,9 @@ static void HIPRT_CB mscclSetupProxyCallback(void *args) { } ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStream_t stream) { - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); mscclThreadLocalStatus& threadLocalStatus = 
mscclGetThreadLocalStatus(); - mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(); + mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm->rank); if (threadLocalStatus.captureStatus == mscclNoCapture) { INFO(NCCL_NET,"mscclSetupProxy: no capture\n"); NCCLCHECK(mscclSetupProxyImpl(hostAlgo, comm)); @@ -203,7 +202,7 @@ ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStr p.userData = params; CUDACHECK(hipGraphAddHostNode(&callbackNode, threadLocalStatus.graph, nullptr, 0, &p)); } - mscclGetSavedProxyArgs()[threadLocalStatus.captureId].emplace_back(hostAlgo, comm); + savedProxyArgs[threadLocalStatus.captureId].emplace_back(hostAlgo, comm); } return ncclSuccess; } @@ -410,7 +409,7 @@ RCCL_PARAM(MscclForceFullOps, "MSCCL_FORCE_FULLOPS", 0); ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count, ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo, ncclComm_t comm, hipStream_t stream) { - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus(); if (status.lastStream != stream && status.lastStream != nullptr) { diff --git a/projects/rccl/src/misc/msccl/msccl_status.cc b/projects/rccl/src/misc/msccl/msccl_status.cc index 37a1641efd..3812112dd6 100644 --- a/projects/rccl/src/misc/msccl/msccl_status.cc +++ b/projects/rccl/src/misc/msccl/msccl_status.cc @@ -6,9 +6,60 @@ #include "msccl/msccl_status.h" #include "msccl/msccl_struct.h" -mscclStatus& mscclGetStatus() { - static mscclStatus status; - return status; +#include "debug.h" + +#include +#include +#include +using namespace std; + +struct mscclRankState { + int rank; + bool initialized; + mscclStatus status; + mscclSavedProxyArgs savedProxyArgs; + + mscclRankState() : rank(-1), initialized(false), status(), savedProxyArgs() {} + explicit mscclRankState(const mscclRankState&) = default; +}; + 
+static mutex rankStatesMutex; +static unordered_map> rankStates; + +static inline mscclRankState& mscclGetRankState(int rank) { + static thread_local shared_ptr threadRankState = make_shared(); + + if (rank < 0) { + return *threadRankState; + } + + lock_guard lock(rankStatesMutex); + + auto rankStateIt = rankStates.find(rank); + if (rankStateIt == rankStates.end()) { + rankStateIt = rankStates.insert(make_pair(rank, make_shared(*threadRankState))).first; + rankStateIt->second->rank = rank; + } + return *(rankStateIt->second); +} + +bool mscclInitialized(int rank) { + return mscclGetRankState(rank).initialized; +} + +void mscclSetInitialized(int rank, bool initialized) { + auto& state = mscclGetRankState(rank); + assert(!initialized || !state.initialized); + state.initialized = initialized; +} + +void mscclRemoveRank(int rank) { + lock_guard lock(rankStatesMutex); + rankStates.erase(rank); +} + +mscclStatus& mscclGetStatus(int rank) { + return mscclGetRankState(rank).status; } mscclThreadLocalStatus& mscclGetThreadLocalStatus() { @@ -16,7 +67,6 @@ mscclThreadLocalStatus& mscclGetThreadLocalStatus() { return threadLocalStatus; } -mscclSavedProxyArgs& mscclGetSavedProxyArgs() { - static mscclSavedProxyArgs savedProxyArgs; - return savedProxyArgs; +mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank) { + return mscclGetRankState(rank).savedProxyArgs; } diff --git a/projects/rccl/src/msccl.cc b/projects/rccl/src/msccl.cc index fd33c94961..e74e9a8b09 100644 --- a/projects/rccl/src/msccl.cc +++ b/projects/rccl/src/msccl.cc @@ -12,7 +12,7 @@ NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank); ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank) { - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(rank); if (status.freeAlgoHandles.size() == 0) { WARN("MSCCL: MSCCL_MAX_NUM_ALGOS (%d) limit reached", MSCCL_MAX_NUM_ALGOS); @@ -56,17 
+56,17 @@ ncclResult_t mscclRunAlgo( NvtxParamsMsccl payload{sendCounts[comm->rank] * ncclTypeSize(dataType), recvCounts[comm->rank] * ncclTypeSize(dataType)}; NVTX3_FUNC_WITH_PARAMS(MSCCL, MscclSchema, payload) - mscclStatus& status = mscclGetStatus(); + mscclStatus& status = mscclGetStatus(comm->rank); struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle]; struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle]; - NCCLCHECK(mscclGetCaptureStatus(stream)); + NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream)); NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType)); NCCLCHECK(mscclSetupScratch(hostAlgo, stream)); - NCCLCHECK(mscclSetupSyncFlags(stream)); + NCCLCHECK(mscclSetupSyncFlags(comm->rank, stream)); NCCLCHECK(mscclSetupProxy(hostAlgo, comm, stream)); @@ -77,19 +77,6 @@ ncclResult_t mscclRunAlgo( NCCL_API(ncclResult_t, mscclUnloadAlgo, mscclAlgoHandle_t mscclAlgoHandle); ncclResult_t mscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle) { - mscclStatus& status = mscclGetStatus(); - - free(status.hostAlgos[mscclAlgoHandle]); - status.hostAlgos.erase(mscclAlgoHandle); - - NCCLCHECK(ncclCudaFree(status.devAlgos[mscclAlgoHandle])); - status.devAlgos.erase(mscclAlgoHandle); - - status.freeAlgoHandles.push_back(mscclAlgoHandle); - - for (auto &s : status.connectedAlgos) { - s.second.erase(mscclAlgoHandle); - } - + // deprecated return ncclSuccess; } diff --git a/projects/rccl/src/nccl.h.in b/projects/rccl/src/nccl.h.in index c324319e7c..1d127b0b4f 100644 --- a/projects/rccl/src/nccl.h.in +++ b/projects/rccl/src/nccl.h.in @@ -766,6 +766,7 @@ ncclResult_t pmscclRunAlgo( /*! @endcond */ /*! @brief MSCCL Unload Algorithm + @deprecated This function has been removed from the public API. @details Unload MSCCL algorithm previous loaded using its handle. This API is expected to be called by MSCCL scheduler instead of end users. @return Result code. See @ref rccl_result_code for more details. 
diff --git a/projects/rccl/test/CMakeLists.txt b/projects/rccl/test/CMakeLists.txt index 78b3f0d3ba..7fcbea1c08 100644 --- a/projects/rccl/test/CMakeLists.txt +++ b/projects/rccl/test/CMakeLists.txt @@ -5,6 +5,8 @@ cmake_minimum_required(VERSION 2.8.12) if(BUILD_TESTS) + option(OPENMP_TESTS_ENABLED "Enable OpenMP for unit tests" OFF) + message("Building rccl unit tests (Installed in /test/rccl-UnitTests)") find_package(hsa-runtime64 PATHS /opt/rocm ) @@ -23,6 +25,10 @@ if(BUILD_TESTS) find_library(ROCR_LIB ${CORE_RUNTIME_TARGET} PATHS ${ROCR_LIB_DIR} "/opt/rocm" PATH_SUFFIXES lib lib64 REQUIRED) endif() + if(OPENMP_TESTS_ENABLED) + find_package(OpenMP REQUIRED) + endif() + include_directories(${GTEST_INCLUDE_DIRS} ./common) # Collect testing framework source files @@ -60,12 +66,23 @@ if(BUILD_TESTS) if(LL128_ENABLED) target_compile_definitions(rccl-UnitTests PRIVATE ENABLE_LL128) endif() + if(OPENMP_TESTS_ENABLED) + target_compile_definitions(rccl-UnitTests PRIVATE ENABLE_OPENMP) + endif() target_compile_definitions(rccl-UnitTests PRIVATE ROCM_PATH="${ROCM_PATH}") + ## Set rccl-UnitTests compile options + if(OPENMP_TESTS_ENABLED) + target_compile_options(rccl-UnitTests PRIVATE "${OpenMP_CXX_FLAGS}") + endif() + ## Set rccl-UnitTests linked libraries target_link_libraries(rccl-UnitTests PRIVATE ${GTEST_BOTH_LIBRARIES}) target_link_libraries(rccl-UnitTests PRIVATE hip::host hip::device hsa-runtime64::hsa-runtime64) target_link_libraries(rccl-UnitTests PRIVATE Threads::Threads) + if(OPENMP_TESTS_ENABLED) + target_link_libraries(rccl-UnitTests PRIVATE "${OpenMP_CXX_FLAGS}") + endif() # rccl-UnitTests using static library of rccl requires passing rccl # through -l and -L instead of command line input. 
diff --git a/projects/rccl/test/common/EnvVars.cpp b/projects/rccl/test/common/EnvVars.cpp index d8cc5e8bef..c51b7be012 100644 --- a/projects/rccl/test/common/EnvVars.cpp +++ b/projects/rccl/test/common/EnvVars.cpp @@ -108,6 +108,7 @@ namespace RcclUnitTesting showTiming = GetEnvVar("UT_SHOW_TIMING", 1); useInteractive = GetEnvVar("UT_INTERACTIVE", 0); timeoutUs = GetEnvVar("UT_TIMEOUT_US" , 5000000); + useMultithreading = GetEnvVar("UT_MULTITHREAD", false); // Total number of reduction ops int numOps = ncclNumOps; @@ -232,7 +233,8 @@ namespace RcclUnitTesting std::make_tuple("UT_PRINT_VALUES" , printValues , "Print array values (-1 for all)"), std::make_tuple("UT_SHOW_TIMING" , showTiming , "Show timing table"), std::make_tuple("UT_INTERACTIVE" , useInteractive, "Run in interactive mode"), - std::make_tuple("UT_TIMEOUT_US" , timeoutUs , "Timeout limit for collective calls in us") + std::make_tuple("UT_TIMEOUT_US" , timeoutUs , "Timeout limit for collective calls in us"), + std::make_tuple("UT_MULTITHREAD" , useMultithreading, "Multi-thread single-process ranks"), }; printf("================================================================================\n"); diff --git a/projects/rccl/test/common/EnvVars.hpp b/projects/rccl/test/common/EnvVars.hpp index bf54611fc0..ea122e5971 100644 --- a/projects/rccl/test/common/EnvVars.hpp +++ b/projects/rccl/test/common/EnvVars.hpp @@ -29,6 +29,7 @@ namespace RcclUnitTesting bool showTiming; // Show timing per case at end [UT_SHOW_TIMING] bool useInteractive; // Run in interactive mode [UT_INTERACTIVE] int timeoutUs; // Set timeout for child in microseconds [UT_TIMEOUT_US] + bool useMultithreading; // Multi-thread single-process ranks [UT_MULTITHREAD] bool isGfx94; // Detects if architecture is gfx94 // Constructor that parses and collects environment variables diff --git a/projects/rccl/test/common/ErrCode.hpp b/projects/rccl/test/common/ErrCode.hpp index f8ab991b41..0eebf42e00 100644 --- 
a/projects/rccl/test/common/ErrCode.hpp +++ b/projects/rccl/test/common/ErrCode.hpp @@ -8,7 +8,7 @@ namespace RcclUnitTesting { - typedef enum + typedef enum : int { TEST_SUCCESS = 0, TEST_FAIL = 1, @@ -17,25 +17,41 @@ namespace RcclUnitTesting #define ERROR(...) printf("\033[0;31m" "[ ERROR ] " "\033[0m" __VA_ARGS__) #define INFO(...) printf("[ INFO ] " __VA_ARGS__) +#define WARN(...) printf("[ WARNING ] " __VA_ARGS__) +#define RETURN_RESULT(result) return (result) -#define CHECK_CALL(func) \ - { \ +#define CHECK_CALL_BASE(func, RESULT, RESULT_ARGS...) \ + do { \ ErrCode status = func; \ if (status != TEST_SUCCESS) \ { \ ERROR("Error in call %s\n", #func); \ - return status; \ + RESULT(status, ##RESULT_ARGS); \ } \ - } + } while (false) +#define CHECK_CALL(func) CHECK_CALL_BASE(func, RETURN_RESULT) -#define CHECK_HIP(func) \ - { \ +#define CHECK_HIP_BASE(func, RESULT, RESULT_ARGS...) \ + do { \ hipError_t error = (func); \ if (error != hipSuccess) \ { \ fprintf(stderr, "\033[0;31m" "[ ERROR ] HIP error: %s File:%s Line:%d\n" "\033[m", \ hipGetErrorString(error), strrchr("/" __FILE__, '/') + 1, __LINE__); \ - return TEST_FAIL; \ + RESULT(TEST_FAIL, ##RESULT_ARGS); \ } \ - } + } while (false) +#define CHECK_HIP(func) CHECK_HIP_BASE(func, RETURN_RESULT) + +#ifdef ENABLE_OPENMP +#define OMP_CANCEL_FOR(result, errCode) errCode = (result); _Pragma("omp cancel for") +#define RANK_RESULT(errCode, result) OMP_CANCEL_FOR(result, errCode) +#define CHECK_CALL_RANK(errCode, func) CHECK_CALL_BASE(func, OMP_CANCEL_FOR, errCode) +#define CHECK_HIP_RANK(errCode, func) CHECK_HIP_BASE(func, OMP_CANCEL_FOR, errCode) +#else +#define RANK_RESULT(errCode, result) RETURN_RESULT(result) +#define CHECK_CALL_RANK(errCode, func) CHECK_CALL(func) +#define CHECK_HIP_RANK(errCode, func) CHECK_HIP(func) +#endif } + diff --git a/projects/rccl/test/common/TestBed.cpp b/projects/rccl/test/common/TestBed.cpp index 251ebfd8ae..84b4b74ac5 100644 --- a/projects/rccl/test/common/TestBed.cpp +++ 
b/projects/rccl/test/common/TestBed.cpp @@ -97,7 +97,7 @@ namespace RcclUnitTesting childList.resize(this->numActiveChildren); for (int childId = 0; childId < this->numActiveChildren; ++childId) { - childList[childId] = new TestBedChild(childId, ev.verbose, ev.printValues); + childList[childId] = new TestBedChild(childId, ev.verbose, ev.printValues, ev.useMultithreading); if (childList[childId]->InitPipes() != TEST_SUCCESS) { ERROR("Unable to create pipes to child process\n"); diff --git a/projects/rccl/test/common/TestBedChild.cpp b/projects/rccl/test/common/TestBedChild.cpp index cd9b8de8b3..892585efe7 100644 --- a/projects/rccl/test/common/TestBedChild.cpp +++ b/projects/rccl/test/common/TestBedChild.cpp @@ -8,20 +8,33 @@ #include #include +#ifdef ENABLE_OPENMP +#include <omp.h> +#endif -#define CHILD_NCCL_CALL(cmd, msg) \ - { \ +static int getThreadId() +{ + #ifdef ENABLE_OPENMP + return (int)omp_get_thread_num(); + #else + return -1; + #endif +} + +#define CHILD_NCCL_CALL_BASE(cmd, msg, RESULT, RESULT_ARGS...) \ + do { \ if (this->verbose) printf("[ NCCL CALL] " #cmd "\n"); \ ncclResult_t status = cmd; \ if (status != ncclSuccess) \ { \ ERROR("Child process %d fails NCCL call %s with code %d\n", this->childId, msg, status); \ - return TEST_FAIL; \ + RESULT(TEST_FAIL, ##RESULT_ARGS); \ } \ - } + } while (false) +#define CHILD_NCCL_CALL(cmd, msg) CHILD_NCCL_CALL_BASE(cmd, msg, RETURN_RESULT) -#define CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank) \ - { \ +#define CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, RESULT, RESULT_ARGS...) 
\ + do { \ unsigned long int loop_counter = 0; \ ncclResult_t ncclAsyncErr; \ loop_counter = 0; \ @@ -34,20 +47,30 @@ if (ncclAsyncErr != ncclSuccess) \ { \ ERROR("Child process %d fails NCCL call %s with code %d\n", this->childId, msg, ncclAsyncErr); \ - return TEST_FAIL; \ + RESULT(TEST_FAIL, ##RESULT_ARGS); \ } \ - } + } while (false) +#define CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, RETURN_RESULT) #define PIPE_READ(val) \ if (read(childReadFd, &val, sizeof(val)) != sizeof(val)) return TEST_FAIL; +#ifdef ENABLE_OPENMP +#define CHILD_NCCL_CALL_RANK(errCode, cmd, msg) CHILD_NCCL_CALL_BASE(cmd, msg, OMP_CANCEL_FOR, errCode) +#define CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, OMP_CANCEL_FOR, errCode) +#else +#define CHILD_NCCL_CALL_RANK(errCode, cmd, msg) CHILD_NCCL_CALL(cmd, msg) +#define CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank) +#endif + namespace RcclUnitTesting { - TestBedChild::TestBedChild(int const childId, bool const verbose, int const printValues) + TestBedChild::TestBedChild(int const childId, bool const verbose, int const printValues, bool const useRankThreading) { this->childId = childId; this->verbose = verbose; this->printValues = printValues; + this->useRankThreading = useRankThreading; } int TestBedChild::InitPipes() @@ -83,6 +106,9 @@ namespace RcclUnitTesting // Wait for commands from parent process if (verbose) INFO("Child %d enters execution loop\n", this->childId); + #ifndef ENABLE_OPENMP + if (verbose && useRankThreading) WARN("Multi-threaded ranks requires ENABLE_OPENMP to be defined\n"); + #endif int command; while (read(childReadFd, &command, sizeof(command)) > 0) { @@ -473,6 +499,8 @@ namespace RcclUnitTesting } } + int numThreadsToUse = this->useRankThreading ? 
numRanksToExecute : 1; + // Start group call CHILD_NCCL_CALL(ncclGroupStart(), "ncclGroupStart"); @@ -480,9 +508,17 @@ namespace RcclUnitTesting for (int collId = 0; collId < this->numCollectivesInGroup[groupId]; ++collId) { // Loop over all local ranks + if (this->verbose && this->useRankThreading) + INFO("Group %d collective %d running %d threads\n", groupId, collId, numThreadsToUse); + ErrCode errCode = TEST_SUCCESS; + auto& errCodeVal = reinterpret_cast<int&>(errCode); + #pragma omp parallel for num_threads(numThreadsToUse) reduction(max : errCodeVal) for (int localRank : localRanksToExecute) { - CHECK_HIP(hipSetDevice(this->deviceIds[localRank])); + if (this->verbose && this->useRankThreading) + INFO("Group %d collective %d running rank %d on thread %d\n", groupId, collId, localRank, getThreadId()); + + CHECK_HIP_RANK(errCode, hipSetDevice(this->deviceIds[localRank])); CollectiveArgs const& collArg = this->collArgs[groupId][localRank][collId]; @@ -492,14 +528,14 @@ namespace RcclUnitTesting PtrUnion inputCpu; size_t const numInputBytes = numInputElementsToPrint * DataTypeToBytes(collArg.dataType); inputCpu.AllocateCpuMem(numInputBytes); - CHECK_HIP(hipMemcpy(inputCpu.ptr, collArg.inputGpu.ptr, numInputBytes, hipMemcpyDeviceToHost)); + CHECK_HIP_RANK(errCode, hipMemcpy(inputCpu.ptr, collArg.inputGpu.ptr, numInputBytes, hipMemcpyDeviceToHost)); printf("[ DEBUG ] Rank %02d Group %d Coll %d %-10s: %s\n", collArg.globalRank, groupId, collId, "Input", inputCpu.ToString(collArg.dataType, numInputElementsToPrint).c_str()); inputCpu.FreeCpuMem(); int const numOutputElementsToPrint = (this->printValues < 0 ? 
collArg.numOutputElements : this->printValues); size_t const numOutputBytes = numOutputElementsToPrint * DataTypeToBytes(collArg.dataType); - CHECK_HIP(hipMemcpy(collArg.outputCpu.ptr, collArg.outputGpu.ptr, numOutputBytes, hipMemcpyDeviceToHost)); + CHECK_HIP_RANK(errCode, hipMemcpy(collArg.outputCpu.ptr, collArg.outputGpu.ptr, numOutputBytes, hipMemcpyDeviceToHost)); printf("[ DEBUG ] Rank %02d Group %d Coll %d %-10s: %s\n", collArg.globalRank, groupId, collId, "Pre-Output", collArg.outputCpu.ToString(collArg.dataType, numOutputElementsToPrint).c_str()); } @@ -507,7 +543,8 @@ namespace RcclUnitTesting switch (collArg.funcType) { case ncclCollBroadcast: - CHILD_NCCL_CALL(ncclBroadcast(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclBroadcast( + collArg.inputGpu.ptr, collArg.outputGpu.ptr, collArg.numInputElements, collArg.dataType, @@ -517,7 +554,8 @@ namespace RcclUnitTesting "ncclBroadcast"); break; case ncclCollReduce: - CHILD_NCCL_CALL(ncclReduce(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclReduce( + collArg.inputGpu.ptr, collArg.outputGpu.ptr, collArg.numInputElements, collArg.dataType, @@ -528,7 +566,8 @@ namespace RcclUnitTesting "ncclReduce"); break; case ncclCollAllGather: - CHILD_NCCL_CALL(ncclAllGather(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclAllGather( + collArg.inputGpu.ptr, collArg.outputGpu.ptr, collArg.numInputElements, collArg.dataType, @@ -537,7 +576,8 @@ namespace RcclUnitTesting "ncclAllGather"); break; case ncclCollReduceScatter: - CHILD_NCCL_CALL(ncclReduceScatter(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclReduceScatter( + collArg.inputGpu.ptr, collArg.outputGpu.ptr, collArg.numOutputElements, collArg.dataType, @@ -547,7 +587,8 @@ namespace RcclUnitTesting "ncclReduceScatter"); break; case ncclCollAllReduce: - CHILD_NCCL_CALL(ncclAllReduce(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclAllReduce( + collArg.inputGpu.ptr, collArg.outputGpu.ptr, collArg.numInputElements, 
collArg.dataType, @@ -557,7 +598,8 @@ namespace RcclUnitTesting "ncclAllReduce"); break; case ncclCollGather: - CHILD_NCCL_CALL(ncclGather(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclGather( + collArg.inputGpu.ptr, collArg.outputGpu.ptr, collArg.numInputElements, collArg.dataType, @@ -567,7 +609,8 @@ namespace RcclUnitTesting "ncclGather"); break; case ncclCollScatter: - CHILD_NCCL_CALL(ncclScatter(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclScatter( + collArg.inputGpu.ptr, collArg.outputGpu.ptr, collArg.numOutputElements, collArg.dataType, @@ -577,7 +620,8 @@ namespace RcclUnitTesting "ncclScatter"); break; case ncclCollAllToAll: - CHILD_NCCL_CALL(ncclAllToAll(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclAllToAll( + collArg.inputGpu.ptr, collArg.outputGpu.ptr, collArg.numInputElements / collArg.totalRanks, collArg.dataType, @@ -586,7 +630,8 @@ namespace RcclUnitTesting "ncclAllToAll"); break; case ncclCollAllToAllv: - CHILD_NCCL_CALL(ncclAllToAllv(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclAllToAllv( + collArg.inputGpu.ptr, collArg.options.sendcounts + (this->rankOffset + localRank)*this->totalRanks, collArg.options.sdispls + (this->rankOffset + localRank)*this->totalRanks, collArg.outputGpu.ptr, @@ -598,7 +643,8 @@ namespace RcclUnitTesting "ncclAllToAllv"); break; case ncclCollSend: - CHILD_NCCL_CALL(ncclSend(collArg.inputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclSend( + collArg.inputGpu.ptr, collArg.numInputElements, collArg.dataType, collArg.options.root, @@ -607,7 +653,8 @@ namespace RcclUnitTesting "ncclSend"); break; case ncclCollRecv: - CHILD_NCCL_CALL(ncclRecv(collArg.outputGpu.ptr, + CHILD_NCCL_CALL_RANK(errCode, ncclRecv( + collArg.outputGpu.ptr, collArg.numOutputElements, collArg.dataType, collArg.options.root, @@ -617,14 +664,18 @@ namespace RcclUnitTesting break; default: ERROR("Unknown func type %d\n", collArg.funcType); - return TEST_FAIL; + RANK_RESULT(errCode, TEST_FAIL); } if 
(this->useBlocking == false) { - CHILD_NCCL_CALL_NON_BLOCKING("ncclCommGetAsyncErrorExecuteCollectives", localRank); + CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, "ncclCommGetAsyncErrorExecuteCollectives", localRank); } + + if (this->verbose && this->useRankThreading) + INFO("Group %d collective %d done rank %d on thread %d\n", groupId, collId, localRank, getThreadId()); } + if (this->useRankThreading) CHECK_CALL(errCode); } // End group call if (this->useBlocking == false) diff --git a/projects/rccl/test/common/TestBedChild.hpp b/projects/rccl/test/common/TestBedChild.hpp index fdb65964ba..34bef5cebf 100644 --- a/projects/rccl/test/common/TestBedChild.hpp +++ b/projects/rccl/test/common/TestBedChild.hpp @@ -57,6 +57,7 @@ namespace RcclUnitTesting pid_t pid; bool verbose; int printValues; + bool useRankThreading; // Pipes used to communicate between parent process int parentWriteFd; @@ -80,7 +81,7 @@ namespace RcclUnitTesting std::vector>> graphEnabled; // Constructor - TestBedChild(int const childId, bool const verbose, int const printValues); + TestBedChild(int const childId, bool const verbose, int const printValues, bool const useRankThreading); // Prepare parent/child communication pipes - to be executed by parent process int InitPipes();