Enable multi-threading for MSCCL (#1203)
MSCCL can now run in a multi-threaded configuration. To exercise this in the unit tests, this commit adds the ENABLE_OPENMP compile definition and the --openmp-test-enable flag to the unit test build script. To activate multi-threaded mode at run time, set the environment variables UT_MULTITHREAD=1 and UT_PROCESS_MASK=1. Jenkins is set to use this mode.
Этот коммит содержится в:
коммит произвёл
GitHub
родитель
45f3fbc52f
Коммит
0c36d571ea
@@ -24,7 +24,7 @@ def runTestCommand (platform, project, gfilter, envars)
|
||||
cd ${project.paths.project_build_prefix}/build/release/test
|
||||
${sudo} ulimit -l unlimited
|
||||
ulimit -a
|
||||
${sudo} ${envars} RCCL_ENABLE_SIGNALHANDLER=1 NCCL_DEBUG=INFO HSA_FORCE_FINE_GRAIN_PCIE=1 ./rccl-UnitTests --gtest_filter=${gfilter} --gtest_output=xml --gtest_color=yes
|
||||
${sudo} ${envars} RCCL_ENABLE_SIGNALHANDLER=1 NCCL_DEBUG=INFO HSA_FORCE_FINE_GRAIN_PCIE=1 UT_MULTITHREAD=1 UT_PROCESS_MASK=1 ./rccl-UnitTests --gtest_filter=${gfilter} --gtest_output=xml --gtest_color=yes
|
||||
"""
|
||||
|
||||
platform.runCommand(this, command)
|
||||
|
||||
+9
-1
@@ -26,6 +26,7 @@ install_prefix="${ROCM_PATH}"
|
||||
msccl_kernel_enabled=true
|
||||
num_parallel_jobs=$(nproc)
|
||||
npkit_enabled=false
|
||||
openmp_test_enabled=false
|
||||
roctx_enabled=false
|
||||
run_tests=false
|
||||
run_tests_all=false
|
||||
@@ -52,6 +53,7 @@ function display_help()
|
||||
echo " --amdgpu_targets Only compile for specified GPU architecture(s). For multiple targets, seperate by ';' (builds for all supported GPU architectures by default)"
|
||||
echo " --no_clean Don't delete files if they already exist"
|
||||
echo " --npkit-enable Compile with npkit enabled"
|
||||
echo " --openmp-test-enable Enable OpenMP in rccl unit tests"
|
||||
echo " --roctx-enable Compile with roctx enabled (example usage: rocprof --roctx-trace ./rccl-program)"
|
||||
echo " -p|--package_build Build RCCL package"
|
||||
echo " --prefix Specify custom directory to install RCCL to (default: \`/opt/rocm\`)"
|
||||
@@ -71,7 +73,7 @@ function display_help()
|
||||
# check if we have a modern version of getopt that can handle whitespace and long parameters
|
||||
getopt -T
|
||||
if [[ "$?" -eq 4 ]]; then
|
||||
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
|
||||
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
|
||||
else
|
||||
echo "Need a new version of getopt"
|
||||
exit 1
|
||||
@@ -100,6 +102,7 @@ while true; do
|
||||
--amdgpu_targets) build_amdgpu_targets=${2}; shift 2 ;;
|
||||
--no_clean) clean_build=false; shift ;;
|
||||
--npkit-enable) npkit_enabled=true; shift ;;
|
||||
--openmp-test-enable) openmp_test_enabled=true; shift ;;
|
||||
--roctx-enable) roctx_enabled=true; shift ;;
|
||||
-p | --package_build) build_package=true; shift ;;
|
||||
--prefix) install_library=true; install_prefix=${2}; shift 2 ;;
|
||||
@@ -246,6 +249,11 @@ if [[ "${roctx_enabled}" == true ]]; then
|
||||
cmake_common_options="${cmake_common_options} -DROCTX=ON"
|
||||
fi
|
||||
|
||||
# Enable OpenMP in unit tests
|
||||
if [[ "${openmp_test_enabled}" == true ]]; then
|
||||
cmake_common_options="${cmake_common_options} -DOPENMP_TESTS_ENABLED=ON"
|
||||
fi
|
||||
|
||||
# Enable NPKit
|
||||
npkit_options=""
|
||||
if [[ "${npkit_enabled}" == true ]]; then
|
||||
|
||||
+11
-11
@@ -23,7 +23,7 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun
|
||||
size_t msgsize = sendcount * ncclTypeSize(datatype);
|
||||
NVTX3_FUNC_WITH_PARAMS(AllGather, AllGatherSchema, msgsize)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
|
||||
@@ -52,7 +52,7 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
NvtxParamsAllReduce payload{count * ncclTypeSize(datatype), op};
|
||||
NVTX3_FUNC_WITH_PARAMS(AllReduce, AllReduceSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream);
|
||||
@@ -75,7 +75,7 @@ ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, nc
|
||||
size_t msgsize = count * ncclTypeSize(datatype);
|
||||
NVTX3_FUNC_WITH_PARAMS(AllToAll, AllToAllSchema, msgsize)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream);
|
||||
@@ -122,7 +122,7 @@ ncclResult_t ncclAllToAllv(const void *sendbuff, const size_t sendcounts[], cons
|
||||
NvtxParamsAllToAllv payload{sendcounts[comm->rank] * ncclTypeSize(datatype), recvcounts[comm->rank] * ncclTypeSize(datatype)};
|
||||
NVTX3_FUNC_WITH_PARAMS(AllToAllv, AllToAllvSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls,
|
||||
0, datatype, 0, 0, ncclSum, mscclFuncAllToAllv, comm, stream);
|
||||
@@ -166,7 +166,7 @@ ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, n
|
||||
NvtxParamsBroadcast payload{count * ncclTypeSize(datatype), root};
|
||||
NVTX3_FUNC_WITH_PARAMS(Broadcast, BroadcastSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream);
|
||||
@@ -200,7 +200,7 @@ ncclResult_t ncclGather(const void* sendbuff, void* recvbuff, size_t sendcount,
|
||||
NvtxParamsGather payload{sendcount * ncclTypeSize(datatype), root};
|
||||
NVTX3_FUNC_WITH_PARAMS(Gather, GatherSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
sendcount, datatype, root, 0, ncclSum, mscclFuncGather, comm, stream);
|
||||
@@ -240,7 +240,7 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count,
|
||||
NvtxParamsReduce payload{count * ncclTypeSize(datatype), root, op};
|
||||
NVTX3_FUNC_WITH_PARAMS(Reduce, ReduceSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, root, 0, op, mscclFuncReduce, comm, stream);
|
||||
@@ -268,7 +268,7 @@ ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recv
|
||||
NvtxParamsReduceScatter payload{recvcount * ncclTypeSize(datatype), op};
|
||||
NVTX3_FUNC_WITH_PARAMS(ReduceScatter, ReduceScatterSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream);
|
||||
@@ -295,7 +295,7 @@ ncclResult_t ncclScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
|
||||
NvtxParamsScatter payload{recvcount * ncclTypeSize(datatype), root};
|
||||
NVTX3_FUNC_WITH_PARAMS(Scatter, ScatterSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
recvcount, datatype, root, 0, ncclSum, mscclFuncScatter, comm, stream);
|
||||
@@ -333,7 +333,7 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp
|
||||
NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer};
|
||||
NVTX3_FUNC_WITH_PARAMS(Send, SendRecvSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream);
|
||||
@@ -356,7 +356,7 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int
|
||||
NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer};
|
||||
NVTX3_FUNC_WITH_PARAMS(Recv, SendRecvSchema, payload)
|
||||
|
||||
if (mscclAvailable() && !mscclIsCaller()) {
|
||||
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
|
||||
return mscclEnqueueCheck(
|
||||
nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr,
|
||||
count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream);
|
||||
|
||||
@@ -17,7 +17,7 @@ void mscclSetIsCallerFlag();
|
||||
void mscclClearIsCallerFlag();
|
||||
bool mscclIsCaller();
|
||||
|
||||
bool mscclAvailable();
|
||||
bool mscclAvailable(int rank = -1);
|
||||
|
||||
ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired);
|
||||
|
||||
@@ -33,7 +33,7 @@ ncclResult_t mscclEnqueueCheck(
|
||||
|
||||
ncclResult_t mscclGroupEnd();
|
||||
|
||||
ncclResult_t mscclTeardown();
|
||||
ncclResult_t mscclTeardown(int rank);
|
||||
|
||||
size_t mscclKernMaxLocalSize();
|
||||
|
||||
|
||||
@@ -11,11 +11,11 @@
|
||||
#include "comm.h"
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
ncclResult_t mscclGetCaptureStatus(hipStream_t stream);
|
||||
ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclSetupSyncFlags(hipStream_t stream);
|
||||
ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream);
|
||||
|
||||
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm);
|
||||
|
||||
|
||||
@@ -8,9 +8,15 @@
|
||||
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
mscclStatus& mscclGetStatus();
|
||||
bool mscclInitialized(int rank);
|
||||
|
||||
mscclSavedProxyArgs& mscclGetSavedProxyArgs();
|
||||
void mscclSetInitialized(int rank, bool initialized = true);
|
||||
|
||||
void mscclRemoveRank(int rank);
|
||||
|
||||
mscclStatus& mscclGetStatus(int rank);
|
||||
|
||||
mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank);
|
||||
|
||||
mscclThreadLocalStatus& mscclGetThreadLocalStatus();
|
||||
|
||||
|
||||
+2
-2
@@ -1738,7 +1738,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
|
||||
|
||||
if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) {
|
||||
NCCLCHECK(mscclInit(comm));
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
status.needsProxy |= mscclNeedsProxy;
|
||||
}
|
||||
|
||||
@@ -2351,7 +2351,7 @@ static ncclResult_t commCleanup(ncclComm_t comm) {
|
||||
#endif
|
||||
|
||||
if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) {
|
||||
NCCLCHECK(mscclTeardown());
|
||||
NCCLCHECK(mscclTeardown(comm->rank));
|
||||
}
|
||||
|
||||
return ncclSuccess;
|
||||
|
||||
@@ -25,8 +25,6 @@
|
||||
RCCL_PARAM(MscclEnabled, "MSCCL_ENABLE", 1);
|
||||
RCCL_PARAM(MscclForceEnabled, "MSCCL_FORCE_ENABLE", 0);
|
||||
static const char* mscclAlgoFilePathEnv = "MSCCL_ALGO_FILE_PATH";
|
||||
static std::atomic<bool> mscclInitialized;
|
||||
static std::mutex mscclLifecycleMutex;
|
||||
|
||||
bool mscclEnabled() {
|
||||
#ifdef COMPILE_MSCCL_KERNEL
|
||||
@@ -56,23 +54,12 @@ bool mscclIsCaller() {
|
||||
return mscclGetThreadLocalStatus().mscclIsCallerFlag;
|
||||
}
|
||||
|
||||
bool mscclAvailable() {
|
||||
return mscclEnabled() && mscclInitialized.load(std::memory_order_acquire);
|
||||
bool mscclAvailable(int rank) {
|
||||
return mscclEnabled() && mscclInitialized(rank);
|
||||
}
|
||||
|
||||
static bool mscclCommCompatible(ncclComm_t comm) {
|
||||
std::map<uint64_t, std::set<uint64_t>> hostHashToPidHashes;
|
||||
for (int i = 0; i < comm->nRanks; i++) {
|
||||
uint64_t hostHash = comm->peerInfo[i].hostHash;
|
||||
uint64_t pidHash = comm->peerInfo[i].pidHash;
|
||||
if (hostHashToPidHashes.find(hostHash) != hostHashToPidHashes.end()) {
|
||||
auto& pidHashSet = hostHashToPidHashes[hostHash];
|
||||
if (pidHashSet.find(pidHash) != pidHashSet.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
hostHashToPidHashes[hostHash].insert(pidHash);
|
||||
}
|
||||
// MSCCL is always compatible now. No need to guard against multi-thread.
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -86,8 +73,8 @@ static const char* mscclAlgoShareDirPath = "../share/rccl/msccl-algorithms";
|
||||
static const char* mscclUnitTestAlgoShareDirPath = "../share/rccl/msccl-unit-test-algorithms";
|
||||
|
||||
static ncclResult_t mscclInternalSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
|
||||
static bool mscclAlgoMetaLoaded = false;
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
static thread_local bool mscclAlgoMetaLoaded = false;
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
|
||||
*numChannelsRequired = 0;
|
||||
// Query numChannelsRequired from loaded algorithm metas
|
||||
@@ -166,9 +153,7 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mscclLifecycleMutex);
|
||||
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
bool useInternalScheduler = false;
|
||||
|
||||
const char* mscclSchedulerPath = getenv(mscclSchedulerPathEnv);
|
||||
@@ -199,20 +184,11 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
|
||||
}
|
||||
|
||||
ncclResult_t mscclInit(ncclComm_t comm) {
|
||||
// Always initialize thread local status
|
||||
mscclThreadLocalStatus threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
threadLocalStatus.groupStatus = mscclNoGroup;
|
||||
threadLocalStatus.groupDepth = 0;
|
||||
threadLocalStatus.captureId = ULLONG_MAX;
|
||||
threadLocalStatus.captureStatus = mscclNoCapture;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mscclLifecycleMutex);
|
||||
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
|
||||
// freeAlgoHandles and needsProxy are initialized globally once and before algorithm pre-processing and connection
|
||||
if (!mscclInitialized.load(std::memory_order_acquire)) {
|
||||
if (!mscclInitialized(comm->rank)) {
|
||||
status.freeAlgoHandles.resize(MSCCL_MAX_NUM_ALGOS);
|
||||
for (int i = 0; i < MSCCL_MAX_NUM_ALGOS; i++) {
|
||||
status.freeAlgoHandles[i] = MSCCL_MAX_NUM_ALGOS - i - 1;
|
||||
@@ -229,6 +205,8 @@ ncclResult_t mscclInit(ncclComm_t comm) {
|
||||
if (m.nRanks == comm->nRanks) {
|
||||
// Load algorithms
|
||||
if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()) {
|
||||
static std::mutex loadAlgoMutex;
|
||||
std::lock_guard<std::mutex> lock(loadAlgoMutex);
|
||||
NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm->rank));
|
||||
}
|
||||
// Connect algorithms
|
||||
@@ -241,7 +219,7 @@ ncclResult_t mscclInit(ncclComm_t comm) {
|
||||
}
|
||||
}
|
||||
|
||||
if (mscclInitialized.load(std::memory_order_acquire)) {
|
||||
if (mscclInitialized(comm->rank)) {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -250,7 +228,7 @@ ncclResult_t mscclInit(ncclComm_t comm) {
|
||||
status.lastStream = nullptr;
|
||||
NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus)));
|
||||
|
||||
mscclInitialized.store(true, std::memory_order_release);
|
||||
mscclSetInitialized(comm->rank);
|
||||
}
|
||||
|
||||
INFO(NCCL_INIT, "MSCCL: Initialization finished, localSize %ld", mscclKernMaxLocalSize());
|
||||
@@ -266,8 +244,8 @@ ncclResult_t mscclGroupStart() {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam* param) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSchedulerParam* param) {
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
param->scheduled = false;
|
||||
|
||||
// Current MSCCL doesn't support pre/post op
|
||||
@@ -312,13 +290,13 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam*
|
||||
}
|
||||
|
||||
static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* param) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(param->comm->rank);
|
||||
if (status.mscclSchedulerPtr) {
|
||||
NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p)));
|
||||
} else {
|
||||
// Disable MSCCL algorithms if machine type is not matching
|
||||
if (param->comm->topo->mscclEnabled || mscclForceEnabled()) {
|
||||
NCCLCHECK(mscclInternalSchedulerSelectAlgo(&(param->p)));
|
||||
NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm->rank, &(param->p)));
|
||||
} else {
|
||||
param->p.scheduled = false;
|
||||
}
|
||||
@@ -513,12 +491,30 @@ ncclResult_t mscclGroupEnd() {
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
static ncclResult_t mscclInternalSchedulerTeardown() {
|
||||
// Unloads a single algorithm from the state kept for `rank`.
//
// The entry stored under `mscclAlgoHandle` is released from both hostAlgos
// (via free) and devAlgos (via ncclCudaFree), the handle is appended to
// freeAlgoHandles, and the handle is erased from every set held in
// connectedAlgos. Returns ncclSuccess once all of that has run.
static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlgoHandle) {
  mscclStatus& rankStatus = mscclGetStatus(rank);

  // Release the host-side entry.
  free(rankStatus.hostAlgos[mscclAlgoHandle]);
  rankStatus.hostAlgos.erase(mscclAlgoHandle);

  // Release the device-side entry.
  NCCLCHECK(ncclCudaFree(rankStatus.devAlgos[mscclAlgoHandle]));
  rankStatus.devAlgos.erase(mscclAlgoHandle);

  // Make the handle available again.
  rankStatus.freeAlgoHandles.push_back(mscclAlgoHandle);

  // Forget the handle in every connected-algorithm set.
  for (auto it = rankStatus.connectedAlgos.begin(); it != rankStatus.connectedAlgos.end(); ++it) {
    it->second.erase(mscclAlgoHandle);
  }

  return ncclSuccess;
}
|
||||
|
||||
static ncclResult_t mscclInternalSchedulerTeardown(int rank) {
|
||||
ncclResult_t ret = ncclSuccess, tmpRet = ncclSuccess;
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
for (auto &m : status.rankToAlgoHandles) {
|
||||
for (auto &p : m) {
|
||||
tmpRet = mscclUnloadAlgo(p.second);
|
||||
tmpRet = mscclInternalUnloadAlgo(rank, p.second);
|
||||
if (ret == ncclSuccess) {
|
||||
ret = tmpRet;
|
||||
}
|
||||
@@ -529,18 +525,13 @@ static ncclResult_t mscclInternalSchedulerTeardown() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
ncclResult_t mscclTeardown() {
|
||||
// Always teardown thread local status
|
||||
mscclThreadLocalStatus threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
threadLocalStatus.savedSchedulerParams.clear();
|
||||
|
||||
ncclResult_t mscclTeardown(int rank) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mscclLifecycleMutex);
|
||||
|
||||
if (!mscclInitialized.load(std::memory_order_acquire)) {
|
||||
if (!mscclInitialized(rank)) {
|
||||
mscclRemoveRank(rank);
|
||||
return ncclSuccess;
|
||||
}
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
for (auto &p : status.hostAlgos) {
|
||||
free(p.second);
|
||||
status.freeAlgoHandles.push_back(p.first);
|
||||
@@ -563,13 +554,14 @@ ncclResult_t mscclTeardown() {
|
||||
dlclose(status.mscclSchedulerLib);
|
||||
status.mscclSchedulerLib = nullptr;
|
||||
} else {
|
||||
NCCLCHECK(mscclInternalSchedulerTeardown());
|
||||
NCCLCHECK(mscclInternalSchedulerTeardown(rank));
|
||||
}
|
||||
NCCLCHECK(mscclDestroyWorkFifoStatus(&(status.defaultWorkFifoStatus)));
|
||||
for (auto &p : status.graphWorkFifoStatus) {
|
||||
NCCLCHECK(mscclDestroyWorkFifoStatus(&(p.second)));
|
||||
}
|
||||
mscclInitialized.store(false, std::memory_order_release);
|
||||
mscclSetInitialized(rank, false);
|
||||
mscclRemoveRank(rank);
|
||||
}
|
||||
|
||||
INFO(NCCL_INIT, "MSCCL: Teardown finished");
|
||||
|
||||
@@ -22,10 +22,10 @@ static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nC
|
||||
return (nBytes * (size_t)nScratchChunks) / (size_t)nChunksPerLoop;
|
||||
}
|
||||
|
||||
ncclResult_t mscclGetCaptureStatus(hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs();
|
||||
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(rank);
|
||||
cudaStreamCaptureStatus captureStatus;
|
||||
unsigned long long captureId;
|
||||
CUDACHECK(hipStreamGetCaptureInfo_v2(stream, &captureStatus, &captureId, &threadLocalStatus.graph, nullptr, nullptr));
|
||||
@@ -42,12 +42,12 @@ ncclResult_t mscclGetCaptureStatus(hipStream_t stream) {
|
||||
} else {
|
||||
threadLocalStatus.captureStatus = mscclNoCapture;
|
||||
}
|
||||
INFO(NCCL_NET,"mscclGetCaptureStatus: %d, captureId: %llu, size: %lu\n", threadLocalStatus.captureStatus, threadLocalStatus.captureId, mscclGetSavedProxyArgs()[captureId].size());
|
||||
INFO(NCCL_NET,"mscclGetCaptureStatus: %d, captureId: %llu, size: %lu\n", threadLocalStatus.captureStatus, threadLocalStatus.captureId, savedProxyArgs[captureId].size());
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS;
|
||||
status.chunkSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->chunkSteps : 1;
|
||||
status.sliceSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->sliceSteps : 1;
|
||||
@@ -69,12 +69,11 @@ ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupSyncFlags(hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
if (threadLocalStatus.captureStatus == mscclNewCapture ||
|
||||
status.workIndex > (1ULL << (8*sizeof(status.workIndex))) - 2 * NCCL_MAX_OPS - 1) {
|
||||
@@ -86,7 +85,7 @@ ncclResult_t mscclSetupSyncFlags(hipStream_t stream) {
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
|
||||
// Check whether there are enough channels
|
||||
if (hostAlgo->nChannels > comm->nChannels) {
|
||||
@@ -124,7 +123,7 @@ ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm)
|
||||
}
|
||||
|
||||
static ncclResult_t mscclSetupProxyImpl(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
struct ncclProxyOp proxyOp = {};
|
||||
proxyOp.connIndex = 0;
|
||||
@@ -185,9 +184,9 @@ static void HIPRT_CB mscclSetupProxyCallback(void *args) {
|
||||
}
|
||||
|
||||
ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs();
|
||||
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm->rank);
|
||||
if (threadLocalStatus.captureStatus == mscclNoCapture) {
|
||||
INFO(NCCL_NET,"mscclSetupProxy: no capture\n");
|
||||
NCCLCHECK(mscclSetupProxyImpl(hostAlgo, comm));
|
||||
@@ -203,7 +202,7 @@ ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStr
|
||||
p.userData = params;
|
||||
CUDACHECK(hipGraphAddHostNode(&callbackNode, threadLocalStatus.graph, nullptr, 0, &p));
|
||||
}
|
||||
mscclGetSavedProxyArgs()[threadLocalStatus.captureId].emplace_back(hostAlgo, comm);
|
||||
savedProxyArgs[threadLocalStatus.captureId].emplace_back(hostAlgo, comm);
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
@@ -410,7 +409,7 @@ RCCL_PARAM(MscclForceFullOps, "MSCCL_FORCE_FULLOPS", 0);
|
||||
ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count,
|
||||
ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo,
|
||||
ncclComm_t comm, hipStream_t stream) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
|
||||
|
||||
if (status.lastStream != stream && status.lastStream != nullptr) {
|
||||
|
||||
@@ -6,9 +6,60 @@
|
||||
#include "msccl/msccl_status.h"
|
||||
#include "msccl/msccl_struct.h"
|
||||
|
||||
mscclStatus& mscclGetStatus() {
|
||||
static mscclStatus status;
|
||||
return status;
|
||||
#include "debug.h"
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <unordered_map>
|
||||
using namespace std;
|
||||
|
||||
struct mscclRankState {
|
||||
int rank;
|
||||
bool initialized;
|
||||
mscclStatus status;
|
||||
mscclSavedProxyArgs savedProxyArgs;
|
||||
|
||||
mscclRankState() : rank(-1), initialized(false), status(), savedProxyArgs() {}
|
||||
explicit mscclRankState(const mscclRankState&) = default;
|
||||
};
|
||||
|
||||
static mutex rankStatesMutex;
|
||||
static unordered_map<int, shared_ptr<mscclRankState>> rankStates;
|
||||
|
||||
// Returns the state record for `rank`.
//
// A negative rank yields the calling thread's own thread_local record.
// Any other rank is resolved through the shared rankStates table while
// holding rankStatesMutex; on first use the record is created as a copy of
// the calling thread's record and stamped with `rank`.
//
// NOTE(review): the mutex is released when this function returns, so the
// returned reference is not protected against a concurrent
// mscclRemoveRank(rank) — confirm callers never race on the same rank.
static inline mscclRankState& mscclGetRankState(int rank) {
  static thread_local shared_ptr<mscclRankState> perThreadState = make_shared<mscclRankState>();

  if (rank >= 0) {
    lock_guard<mutex> guard(rankStatesMutex);

    auto found = rankStates.find(rank);
    if (found != rankStates.end()) {
      return *(found->second);
    }

    // First request for this rank: seed from the thread's record.
    shared_ptr<mscclRankState> created = make_shared<mscclRankState>(*perThreadState);
    created->rank = rank;
    rankStates.insert(make_pair(rank, created));
    return *created;
  }

  return *perThreadState;
}
|
||||
|
||||
bool mscclInitialized(int rank) {
|
||||
return mscclGetRankState(rank).initialized;
|
||||
}
|
||||
|
||||
void mscclSetInitialized(int rank, bool initialized) {
|
||||
auto& state = mscclGetRankState(rank);
|
||||
assert(!initialized || !state.initialized);
|
||||
state.initialized = initialized;
|
||||
}
|
||||
|
||||
void mscclRemoveRank(int rank) {
|
||||
lock_guard<mutex> lock(rankStatesMutex);
|
||||
rankStates.erase(rank);
|
||||
}
|
||||
|
||||
// Returns the mscclStatus member of the state record for `rank`.
mscclStatus& mscclGetStatus(int rank) {
  mscclRankState& rankState = mscclGetRankState(rank);
  return rankState.status;
}
|
||||
|
||||
mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
|
||||
@@ -16,7 +67,6 @@ mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
|
||||
return threadLocalStatus;
|
||||
}
|
||||
|
||||
mscclSavedProxyArgs& mscclGetSavedProxyArgs() {
|
||||
static mscclSavedProxyArgs savedProxyArgs;
|
||||
return savedProxyArgs;
|
||||
mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank) {
|
||||
return mscclGetRankState(rank).savedProxyArgs;
|
||||
}
|
||||
|
||||
+5
-18
@@ -12,7 +12,7 @@
|
||||
|
||||
NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
|
||||
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(rank);
|
||||
|
||||
if (status.freeAlgoHandles.size() == 0) {
|
||||
WARN("MSCCL: MSCCL_MAX_NUM_ALGOS (%d) limit reached", MSCCL_MAX_NUM_ALGOS);
|
||||
@@ -56,17 +56,17 @@ ncclResult_t mscclRunAlgo(
|
||||
NvtxParamsMsccl payload{sendCounts[comm->rank] * ncclTypeSize(dataType), recvCounts[comm->rank] * ncclTypeSize(dataType)};
|
||||
NVTX3_FUNC_WITH_PARAMS(MSCCL, MscclSchema, payload)
|
||||
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
mscclStatus& status = mscclGetStatus(comm->rank);
|
||||
struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle];
|
||||
struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle];
|
||||
|
||||
NCCLCHECK(mscclGetCaptureStatus(stream));
|
||||
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
|
||||
|
||||
NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType));
|
||||
|
||||
NCCLCHECK(mscclSetupScratch(hostAlgo, stream));
|
||||
|
||||
NCCLCHECK(mscclSetupSyncFlags(stream));
|
||||
NCCLCHECK(mscclSetupSyncFlags(comm->rank, stream));
|
||||
|
||||
NCCLCHECK(mscclSetupProxy(hostAlgo, comm, stream));
|
||||
|
||||
@@ -77,19 +77,6 @@ ncclResult_t mscclRunAlgo(
|
||||
|
||||
NCCL_API(ncclResult_t, mscclUnloadAlgo, mscclAlgoHandle_t mscclAlgoHandle);
|
||||
ncclResult_t mscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle) {
|
||||
mscclStatus& status = mscclGetStatus();
|
||||
|
||||
free(status.hostAlgos[mscclAlgoHandle]);
|
||||
status.hostAlgos.erase(mscclAlgoHandle);
|
||||
|
||||
NCCLCHECK(ncclCudaFree(status.devAlgos[mscclAlgoHandle]));
|
||||
status.devAlgos.erase(mscclAlgoHandle);
|
||||
|
||||
status.freeAlgoHandles.push_back(mscclAlgoHandle);
|
||||
|
||||
for (auto &s : status.connectedAlgos) {
|
||||
s.second.erase(mscclAlgoHandle);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
@@ -766,6 +766,7 @@ ncclResult_t pmscclRunAlgo(
|
||||
/*! @endcond */
|
||||
|
||||
/*! @brief MSCCL Unload Algorithm
|
||||
@deprecated This function has been removed from the public API.
|
||||
@details Unload MSCCL algorithm previously loaded using its handle. This API
|
||||
is expected to be called by MSCCL scheduler instead of end users.
|
||||
@return Result code. See @ref rccl_result_code for more details.
|
||||
|
||||
@@ -5,6 +5,8 @@ cmake_minimum_required(VERSION 2.8.12)
|
||||
|
||||
if(BUILD_TESTS)
|
||||
|
||||
option(OPENMP_TESTS_ENABLED "Enable OpenMP for unit tests" OFF)
|
||||
|
||||
message("Building rccl unit tests (Installed in /test/rccl-UnitTests)")
|
||||
|
||||
find_package(hsa-runtime64 PATHS /opt/rocm )
|
||||
@@ -23,6 +25,10 @@ if(BUILD_TESTS)
|
||||
find_library(ROCR_LIB ${CORE_RUNTIME_TARGET} PATHS ${ROCR_LIB_DIR} "/opt/rocm" PATH_SUFFIXES lib lib64 REQUIRED)
|
||||
endif()
|
||||
|
||||
if(OPENMP_TESTS_ENABLED)
|
||||
find_package(OpenMP REQUIRED)
|
||||
endif()
|
||||
|
||||
include_directories(${GTEST_INCLUDE_DIRS} ./common)
|
||||
|
||||
# Collect testing framework source files
|
||||
@@ -60,12 +66,23 @@ if(BUILD_TESTS)
|
||||
if(LL128_ENABLED)
|
||||
target_compile_definitions(rccl-UnitTests PRIVATE ENABLE_LL128)
|
||||
endif()
|
||||
if(OPENMP_TESTS_ENABLED)
|
||||
target_compile_definitions(rccl-UnitTests PRIVATE ENABLE_OPENMP)
|
||||
endif()
|
||||
target_compile_definitions(rccl-UnitTests PRIVATE ROCM_PATH="${ROCM_PATH}")
|
||||
|
||||
## Set rccl-UnitTests compile definitions
|
||||
if(OPENMP_TESTS_ENABLED)
|
||||
target_compile_options(rccl-UnitTests PRIVATE "${OpenMP_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
## Set rccl-UnitTests linked libraries
|
||||
target_link_libraries(rccl-UnitTests PRIVATE ${GTEST_BOTH_LIBRARIES})
|
||||
target_link_libraries(rccl-UnitTests PRIVATE hip::host hip::device hsa-runtime64::hsa-runtime64)
|
||||
target_link_libraries(rccl-UnitTests PRIVATE Threads::Threads)
|
||||
if(OPENMP_TESTS_ENABLED)
|
||||
target_link_libraries(rccl-UnitTests PRIVATE "${OpenMP_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
# rccl-UnitTests using static library of rccl requires passing rccl
|
||||
# through -l and -L instead of command line input.
|
||||
|
||||
@@ -108,6 +108,7 @@ namespace RcclUnitTesting
|
||||
showTiming = GetEnvVar("UT_SHOW_TIMING", 1);
|
||||
useInteractive = GetEnvVar("UT_INTERACTIVE", 0);
|
||||
timeoutUs = GetEnvVar("UT_TIMEOUT_US" , 5000000);
|
||||
useMultithreading = GetEnvVar("UT_MULTITHREAD", false);
|
||||
|
||||
// Total number of reduction ops
|
||||
int numOps = ncclNumOps;
|
||||
@@ -232,7 +233,8 @@ namespace RcclUnitTesting
|
||||
std::make_tuple("UT_PRINT_VALUES" , printValues , "Print array values (-1 for all)"),
|
||||
std::make_tuple("UT_SHOW_TIMING" , showTiming , "Show timing table"),
|
||||
std::make_tuple("UT_INTERACTIVE" , useInteractive, "Run in interactive mode"),
|
||||
std::make_tuple("UT_TIMEOUT_US" , timeoutUs , "Timeout limit for collective calls in us")
|
||||
std::make_tuple("UT_TIMEOUT_US" , timeoutUs , "Timeout limit for collective calls in us"),
|
||||
std::make_tuple("UT_MULTITHREAD" , useMultithreading, "Multi-thread single-process ranks"),
|
||||
};
|
||||
|
||||
printf("================================================================================\n");
|
||||
|
||||
@@ -29,6 +29,7 @@ namespace RcclUnitTesting
|
||||
bool showTiming; // Show timing per case at end [UT_SHOW_TIMING]
|
||||
bool useInteractive; // Run in interactive mode [UT_INTERACTIVE]
|
||||
int timeoutUs; // Set timeout for child in microseconds [UT_TIMEOUT_US]
|
||||
bool useMultithreading; // Multi-thread single-process ranks [UT_MULTITHREAD]
|
||||
bool isGfx94; // Detects if architecture is gfx94
|
||||
|
||||
// Constructor that parses and collects environment variables
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
namespace RcclUnitTesting
|
||||
{
|
||||
typedef enum
|
||||
typedef enum : int
|
||||
{
|
||||
TEST_SUCCESS = 0,
|
||||
TEST_FAIL = 1,
|
||||
@@ -17,25 +17,41 @@ namespace RcclUnitTesting
|
||||
|
||||
#define ERROR(...) printf("\033[0;31m" "[ ERROR ] " "\033[0m" __VA_ARGS__)
|
||||
#define INFO(...) printf("[ INFO ] " __VA_ARGS__)
|
||||
#define WARN(...) printf("[ WARNING ] " __VA_ARGS__)
|
||||
#define RETURN_RESULT(result) return (result)
|
||||
|
||||
#define CHECK_CALL(func) \
|
||||
{ \
|
||||
#define CHECK_CALL_BASE(func, RESULT, RESULT_ARGS...) \
|
||||
do { \
|
||||
ErrCode status = func; \
|
||||
if (status != TEST_SUCCESS) \
|
||||
{ \
|
||||
ERROR("Error in call %s\n", #func); \
|
||||
return status; \
|
||||
RESULT(status, ##RESULT_ARGS); \
|
||||
} \
|
||||
}
|
||||
} while (false)
|
||||
#define CHECK_CALL(func) CHECK_CALL_BASE(func, RETURN_RESULT)
|
||||
|
||||
#define CHECK_HIP(func) \
|
||||
{ \
|
||||
#define CHECK_HIP_BASE(func, RESULT, RESULT_ARGS...) \
|
||||
do { \
|
||||
hipError_t error = (func); \
|
||||
if (error != hipSuccess) \
|
||||
{ \
|
||||
fprintf(stderr, "\033[0;31m" "[ ERROR ] HIP error: %s File:%s Line:%d\n" "\033[m", \
|
||||
hipGetErrorString(error), strrchr("/" __FILE__, '/') + 1, __LINE__); \
|
||||
return TEST_FAIL; \
|
||||
RESULT(TEST_FAIL, ##RESULT_ARGS); \
|
||||
} \
|
||||
}
|
||||
} while (false)
|
||||
#define CHECK_HIP(func) CHECK_HIP_BASE(func, RETURN_RESULT)
|
||||
|
||||
#ifdef ENABLE_OPENMP
|
||||
#define OMP_CANCEL_FOR(result, errCode) errCode = (result); _Pragma("omp cancel for")
|
||||
#define RANK_RESULT(errCode, result) OMP_CANCEL_FOR(result, errCode)
|
||||
#define CHECK_CALL_RANK(errCode, func) CHECK_CALL_BASE(func, OMP_CANCEL_FOR, errCode)
|
||||
#define CHECK_HIP_RANK(errCode, func) CHECK_HIP_BASE(func, OMP_CANCEL_FOR, errCode)
|
||||
#else
|
||||
#define RANK_RESULT(errCode, result) RETURN_RESULT(result)
|
||||
#define CHECK_CALL_RANK(errCode, func) CHECK_CALL(func)
|
||||
#define CHECK_HIP_RANK(errCode, func) CHECK_HIP(func)
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ namespace RcclUnitTesting
|
||||
childList.resize(this->numActiveChildren);
|
||||
for (int childId = 0; childId < this->numActiveChildren; ++childId)
|
||||
{
|
||||
childList[childId] = new TestBedChild(childId, ev.verbose, ev.printValues);
|
||||
childList[childId] = new TestBedChild(childId, ev.verbose, ev.printValues, ev.useMultithreading);
|
||||
if (childList[childId]->InitPipes() != TEST_SUCCESS)
|
||||
{
|
||||
ERROR("Unable to create pipes to child process\n");
|
||||
|
||||
@@ -8,20 +8,33 @@
|
||||
|
||||
#include <thread>
|
||||
#include <execinfo.h>
|
||||
#ifdef ENABLE_OPENMP
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
#define CHILD_NCCL_CALL(cmd, msg) \
|
||||
{ \
|
||||
static int getThreadId()
|
||||
{
|
||||
#ifdef ENABLE_OPENMP
|
||||
return (int)omp_get_thread_num();
|
||||
#else
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
|
||||
#define CHILD_NCCL_CALL_BASE(cmd, msg, RESULT, RESULT_ARGS...) \
|
||||
do { \
|
||||
if (this->verbose) printf("[ NCCL CALL] " #cmd "\n"); \
|
||||
ncclResult_t status = cmd; \
|
||||
if (status != ncclSuccess) \
|
||||
{ \
|
||||
ERROR("Child process %d fails NCCL call %s with code %d\n", this->childId, msg, status); \
|
||||
return TEST_FAIL; \
|
||||
RESULT(TEST_FAIL, ##RESULT_ARGS); \
|
||||
} \
|
||||
}
|
||||
} while (false)
|
||||
#define CHILD_NCCL_CALL(cmd, msg) CHILD_NCCL_CALL_BASE(cmd, msg, RETURN_RESULT)
|
||||
|
||||
#define CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank) \
|
||||
{ \
|
||||
#define CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, RESULT, RESULT_ARGS...) \
|
||||
do { \
|
||||
unsigned long int loop_counter = 0; \
|
||||
ncclResult_t ncclAsyncErr; \
|
||||
loop_counter = 0; \
|
||||
@@ -34,20 +47,30 @@
|
||||
if (ncclAsyncErr != ncclSuccess) \
|
||||
{ \
|
||||
ERROR("Child process %d fails NCCL call %s with code %d\n", this->childId, msg, ncclAsyncErr); \
|
||||
return TEST_FAIL; \
|
||||
RESULT(TEST_FAIL, ##RESULT_ARGS); \
|
||||
} \
|
||||
}
|
||||
} while (false)
|
||||
#define CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, RETURN_RESULT)
|
||||
|
||||
#define PIPE_READ(val) \
|
||||
if (read(childReadFd, &val, sizeof(val)) != sizeof(val)) return TEST_FAIL;
|
||||
|
||||
#ifdef ENABLE_OPENMP
|
||||
#define CHILD_NCCL_CALL_RANK(errCode, cmd, msg) CHILD_NCCL_CALL_BASE(cmd, msg, OMP_CANCEL_FOR, errCode)
|
||||
#define CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, OMP_CANCEL_FOR, errCode)
|
||||
#else
|
||||
#define CHILD_NCCL_CALL_RANK(errCode, cmd, msg) CHILD_NCCL_CALL(cmd, msg)
|
||||
#define CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank)
|
||||
#endif
|
||||
|
||||
namespace RcclUnitTesting
|
||||
{
|
||||
TestBedChild::TestBedChild(int const childId, bool const verbose, int const printValues)
|
||||
TestBedChild::TestBedChild(int const childId, bool const verbose, int const printValues, bool const useRankThreading)
|
||||
{
|
||||
this->childId = childId;
|
||||
this->verbose = verbose;
|
||||
this->printValues = printValues;
|
||||
this->useRankThreading = useRankThreading;
|
||||
}
|
||||
|
||||
int TestBedChild::InitPipes()
|
||||
@@ -83,6 +106,9 @@ namespace RcclUnitTesting
|
||||
|
||||
// Wait for commands from parent process
|
||||
if (verbose) INFO("Child %d enters execution loop\n", this->childId);
|
||||
#ifndef ENABLE_OPENMP
|
||||
if (verbose && useRankThreading) WARN("Multi-threaded ranks requires ENABLE_OPENMP to be defined\n");
|
||||
#endif
|
||||
int command;
|
||||
while (read(childReadFd, &command, sizeof(command)) > 0)
|
||||
{
|
||||
@@ -473,6 +499,8 @@ namespace RcclUnitTesting
|
||||
}
|
||||
}
|
||||
|
||||
int numThreadsToUse = this->useRankThreading ? numRanksToExecute : 1;
|
||||
|
||||
// Start group call
|
||||
CHILD_NCCL_CALL(ncclGroupStart(), "ncclGroupStart");
|
||||
|
||||
@@ -480,9 +508,17 @@ namespace RcclUnitTesting
|
||||
for (int collId = 0; collId < this->numCollectivesInGroup[groupId]; ++collId)
|
||||
{
|
||||
// Loop over all local ranks
|
||||
if (this->verbose && this->useRankThreading)
|
||||
INFO("Group %d collective %d running %d threads\n", groupId, collId, numThreadsToUse);
|
||||
ErrCode errCode = TEST_SUCCESS;
|
||||
auto& errCodeVal = reinterpret_cast<int&>(errCode);
|
||||
#pragma omp parallel for num_threads(numThreadsToUse) reduction(max : errCodeVal)
|
||||
for (int localRank : localRanksToExecute)
|
||||
{
|
||||
CHECK_HIP(hipSetDevice(this->deviceIds[localRank]));
|
||||
if (this->verbose && this->useRankThreading)
|
||||
INFO("Group %d collective %d running rank %d on thread %d\n", groupId, collId, localRank, getThreadId());
|
||||
|
||||
CHECK_HIP_RANK(errCode, hipSetDevice(this->deviceIds[localRank]));
|
||||
|
||||
CollectiveArgs const& collArg = this->collArgs[groupId][localRank][collId];
|
||||
|
||||
@@ -492,14 +528,14 @@ namespace RcclUnitTesting
|
||||
PtrUnion inputCpu;
|
||||
size_t const numInputBytes = numInputElementsToPrint * DataTypeToBytes(collArg.dataType);
|
||||
inputCpu.AllocateCpuMem(numInputBytes);
|
||||
CHECK_HIP(hipMemcpy(inputCpu.ptr, collArg.inputGpu.ptr, numInputBytes, hipMemcpyDeviceToHost));
|
||||
CHECK_HIP_RANK(errCode, hipMemcpy(inputCpu.ptr, collArg.inputGpu.ptr, numInputBytes, hipMemcpyDeviceToHost));
|
||||
printf("[ DEBUG ] Rank %02d Group %d Coll %d %-10s: %s\n", collArg.globalRank, groupId, collId, "Input",
|
||||
inputCpu.ToString(collArg.dataType, numInputElementsToPrint).c_str());
|
||||
inputCpu.FreeCpuMem();
|
||||
|
||||
int const numOutputElementsToPrint = (this->printValues < 0 ? collArg.numOutputElements : this->printValues);
|
||||
size_t const numOutputBytes = numOutputElementsToPrint * DataTypeToBytes(collArg.dataType);
|
||||
CHECK_HIP(hipMemcpy(collArg.outputCpu.ptr, collArg.outputGpu.ptr, numOutputBytes, hipMemcpyDeviceToHost));
|
||||
CHECK_HIP_RANK(errCode, hipMemcpy(collArg.outputCpu.ptr, collArg.outputGpu.ptr, numOutputBytes, hipMemcpyDeviceToHost));
|
||||
printf("[ DEBUG ] Rank %02d Group %d Coll %d %-10s: %s\n", collArg.globalRank, groupId, collId, "Pre-Output",
|
||||
collArg.outputCpu.ToString(collArg.dataType, numOutputElementsToPrint).c_str());
|
||||
}
|
||||
@@ -507,7 +543,8 @@ namespace RcclUnitTesting
|
||||
switch (collArg.funcType)
|
||||
{
|
||||
case ncclCollBroadcast:
|
||||
CHILD_NCCL_CALL(ncclBroadcast(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclBroadcast(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numInputElements,
|
||||
collArg.dataType,
|
||||
@@ -517,7 +554,8 @@ namespace RcclUnitTesting
|
||||
"ncclBroadcast");
|
||||
break;
|
||||
case ncclCollReduce:
|
||||
CHILD_NCCL_CALL(ncclReduce(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclReduce(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numInputElements,
|
||||
collArg.dataType,
|
||||
@@ -528,7 +566,8 @@ namespace RcclUnitTesting
|
||||
"ncclReduce");
|
||||
break;
|
||||
case ncclCollAllGather:
|
||||
CHILD_NCCL_CALL(ncclAllGather(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclAllGather(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numInputElements,
|
||||
collArg.dataType,
|
||||
@@ -537,7 +576,8 @@ namespace RcclUnitTesting
|
||||
"ncclAllGather");
|
||||
break;
|
||||
case ncclCollReduceScatter:
|
||||
CHILD_NCCL_CALL(ncclReduceScatter(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclReduceScatter(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numOutputElements,
|
||||
collArg.dataType,
|
||||
@@ -547,7 +587,8 @@ namespace RcclUnitTesting
|
||||
"ncclReduceScatter");
|
||||
break;
|
||||
case ncclCollAllReduce:
|
||||
CHILD_NCCL_CALL(ncclAllReduce(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclAllReduce(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numInputElements,
|
||||
collArg.dataType,
|
||||
@@ -557,7 +598,8 @@ namespace RcclUnitTesting
|
||||
"ncclAllReduce");
|
||||
break;
|
||||
case ncclCollGather:
|
||||
CHILD_NCCL_CALL(ncclGather(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclGather(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numInputElements,
|
||||
collArg.dataType,
|
||||
@@ -567,7 +609,8 @@ namespace RcclUnitTesting
|
||||
"ncclGather");
|
||||
break;
|
||||
case ncclCollScatter:
|
||||
CHILD_NCCL_CALL(ncclScatter(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclScatter(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numOutputElements,
|
||||
collArg.dataType,
|
||||
@@ -577,7 +620,8 @@ namespace RcclUnitTesting
|
||||
"ncclScatter");
|
||||
break;
|
||||
case ncclCollAllToAll:
|
||||
CHILD_NCCL_CALL(ncclAllToAll(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclAllToAll(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numInputElements / collArg.totalRanks,
|
||||
collArg.dataType,
|
||||
@@ -586,7 +630,8 @@ namespace RcclUnitTesting
|
||||
"ncclAllToAll");
|
||||
break;
|
||||
case ncclCollAllToAllv:
|
||||
CHILD_NCCL_CALL(ncclAllToAllv(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclAllToAllv(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.options.sendcounts + (this->rankOffset + localRank)*this->totalRanks,
|
||||
collArg.options.sdispls + (this->rankOffset + localRank)*this->totalRanks,
|
||||
collArg.outputGpu.ptr,
|
||||
@@ -598,7 +643,8 @@ namespace RcclUnitTesting
|
||||
"ncclAllToAllv");
|
||||
break;
|
||||
case ncclCollSend:
|
||||
CHILD_NCCL_CALL(ncclSend(collArg.inputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclSend(
|
||||
collArg.inputGpu.ptr,
|
||||
collArg.numInputElements,
|
||||
collArg.dataType,
|
||||
collArg.options.root,
|
||||
@@ -607,7 +653,8 @@ namespace RcclUnitTesting
|
||||
"ncclSend");
|
||||
break;
|
||||
case ncclCollRecv:
|
||||
CHILD_NCCL_CALL(ncclRecv(collArg.outputGpu.ptr,
|
||||
CHILD_NCCL_CALL_RANK(errCode, ncclRecv(
|
||||
collArg.outputGpu.ptr,
|
||||
collArg.numOutputElements,
|
||||
collArg.dataType,
|
||||
collArg.options.root,
|
||||
@@ -617,14 +664,18 @@ namespace RcclUnitTesting
|
||||
break;
|
||||
default:
|
||||
ERROR("Unknown func type %d\n", collArg.funcType);
|
||||
return TEST_FAIL;
|
||||
RANK_RESULT(errCode, TEST_FAIL);
|
||||
}
|
||||
if (this->useBlocking == false)
|
||||
{
|
||||
CHILD_NCCL_CALL_NON_BLOCKING("ncclCommGetAsyncErrorExecuteCollectives", localRank);
|
||||
CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, "ncclCommGetAsyncErrorExecuteCollectives", localRank);
|
||||
}
|
||||
|
||||
if (this->verbose && this->useRankThreading)
|
||||
INFO("Group %d collective %d done rank %d on thread %d\n", groupId, collId, localRank, getThreadId());
|
||||
}
|
||||
|
||||
if (this->useRankThreading) CHECK_CALL(errCode);
|
||||
}
|
||||
// End group call
|
||||
if (this->useBlocking == false)
|
||||
|
||||
@@ -57,6 +57,7 @@ namespace RcclUnitTesting
|
||||
pid_t pid;
|
||||
bool verbose;
|
||||
int printValues;
|
||||
bool useRankThreading;
|
||||
|
||||
// Pipes used to communicate between parent process
|
||||
int parentWriteFd;
|
||||
@@ -80,7 +81,7 @@ namespace RcclUnitTesting
|
||||
std::vector<std::vector<std::vector<bool>>> graphEnabled;
|
||||
|
||||
// Constructor
|
||||
TestBedChild(int const childId, bool const verbose, int const printValues);
|
||||
TestBedChild(int const childId, bool const verbose, int const printValues, bool const useRankThreading);
|
||||
|
||||
// Prepare parent/child communication pipes - to be executed by parent process
|
||||
int InitPipes();
|
||||
|
||||
Ссылка в новой задаче
Block a user