Enable multi-threading for MSCCL (#1203)

MSCCL can now run in a multi-threaded configuration. To exercise this in the unit tests, this change adds the ENABLE_OPENMP compile definition and the --openmp-test-enable flag to the unit test build script. To activate it at run time, set the environment variables UT_MULTITHREAD=1 and UT_PROCESS_MASK=1. Jenkins is now configured to use this mode.
Tá an tiomantas seo le fáil i:
corey-derochie-amd
2024-07-04 09:34:38 -06:00
tiomanta ag GitHub
tuismitheoir 45f3fbc52f
tiomantas 0c36d571ea
D'athraigh 19 comhad le 279 breiseanna agus 148 scriosta
+1 -1
Féach ar an gComhad
@@ -24,7 +24,7 @@ def runTestCommand (platform, project, gfilter, envars)
cd ${project.paths.project_build_prefix}/build/release/test
${sudo} ulimit -l unlimited
ulimit -a
${sudo} ${envars} RCCL_ENABLE_SIGNALHANDLER=1 NCCL_DEBUG=INFO HSA_FORCE_FINE_GRAIN_PCIE=1 ./rccl-UnitTests --gtest_filter=${gfilter} --gtest_output=xml --gtest_color=yes
${sudo} ${envars} RCCL_ENABLE_SIGNALHANDLER=1 NCCL_DEBUG=INFO HSA_FORCE_FINE_GRAIN_PCIE=1 UT_MULTITHREAD=1 UT_PROCESS_MASK=1 ./rccl-UnitTests --gtest_filter=${gfilter} --gtest_output=xml --gtest_color=yes
"""
platform.runCommand(this, command)
+9 -1
Féach ar an gComhad
@@ -26,6 +26,7 @@ install_prefix="${ROCM_PATH}"
msccl_kernel_enabled=true
num_parallel_jobs=$(nproc)
npkit_enabled=false
openmp_test_enabled=false
roctx_enabled=false
run_tests=false
run_tests_all=false
@@ -52,6 +53,7 @@ function display_help()
echo " --amdgpu_targets Only compile for specified GPU architecture(s). For multiple targets, seperate by ';' (builds for all supported GPU architectures by default)"
echo " --no_clean Don't delete files if they already exist"
echo " --npkit-enable Compile with npkit enabled"
echo " --openmp-test-enable Enable OpenMP in rccl unit tests"
echo " --roctx-enable Compile with roctx enabled (example usage: rocprof --roctx-trace ./rccl-program)"
echo " -p|--package_build Build RCCL package"
echo " --prefix Specify custom directory to install RCCL to (default: \`/opt/rocm\`)"
@@ -71,7 +73,7 @@ function display_help()
# check if we have a modern version of getopt that can handle whitespace and long parameters
getopt -T
if [[ "$?" -eq 4 ]]; then
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
GETOPT_PARSE=$(getopt --name "${0}" --options dfhij:lprt --longoptions address-sanitizer,dependencies,debug,enable_backtrace,disable-colltrace,disable-msccl-kernel,fast,help,install,jobs:,local_gpu_only,amdgpu_targets:,no_clean,npkit-enable,openmp-test-enable,roctx-enable,package_build,prefix:,rm-legacy-include-dir,run_tests_all,run_tests_quick,static,tests_build,time-trace,verbose -- "$@")
else
echo "Need a new version of getopt"
exit 1
@@ -100,6 +102,7 @@ while true; do
--amdgpu_targets) build_amdgpu_targets=${2}; shift 2 ;;
--no_clean) clean_build=false; shift ;;
--npkit-enable) npkit_enabled=true; shift ;;
--openmp-test-enable) openmp_test_enabled=true; shift ;;
--roctx-enable) roctx_enabled=true; shift ;;
-p | --package_build) build_package=true; shift ;;
--prefix) install_library=true; install_prefix=${2}; shift 2 ;;
@@ -246,6 +249,11 @@ if [[ "${roctx_enabled}" == true ]]; then
cmake_common_options="${cmake_common_options} -DROCTX=ON"
fi
# Enable OpenMP in unit tests
if [[ "${openmp_test_enabled}" == true ]]; then
cmake_common_options="${cmake_common_options} -DOPENMP_TESTS_ENABLED=ON"
fi
# Enable NPKit
npkit_options=""
if [[ "${npkit_enabled}" == true ]]; then
+11 -11
Féach ar an gComhad
@@ -23,7 +23,7 @@ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcoun
size_t msgsize = sendcount * ncclTypeSize(datatype);
NVTX3_FUNC_WITH_PARAMS(AllGather, AllGatherSchema, msgsize)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
sendcount, datatype, 0, 0, ncclSum, mscclFuncAllGather, comm, stream);
@@ -52,7 +52,7 @@ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
NvtxParamsAllReduce payload{count * ncclTypeSize(datatype), op};
NVTX3_FUNC_WITH_PARAMS(AllReduce, AllReduceSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, 0, op, mscclFuncAllReduce, comm, stream);
@@ -75,7 +75,7 @@ ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, nc
size_t msgsize = count * ncclTypeSize(datatype);
NVTX3_FUNC_WITH_PARAMS(AllToAll, AllToAllSchema, msgsize)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, 0, ncclSum, mscclFuncAllToAll, comm, stream);
@@ -122,7 +122,7 @@ ncclResult_t ncclAllToAllv(const void *sendbuff, const size_t sendcounts[], cons
NvtxParamsAllToAllv payload{sendcounts[comm->rank] * ncclTypeSize(datatype), recvcounts[comm->rank] * ncclTypeSize(datatype)};
NVTX3_FUNC_WITH_PARAMS(AllToAllv, AllToAllvSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, sendcounts, sdispls, recvbuff, recvcounts, rdispls,
0, datatype, 0, 0, ncclSum, mscclFuncAllToAllv, comm, stream);
@@ -166,7 +166,7 @@ ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, n
NvtxParamsBroadcast payload{count * ncclTypeSize(datatype), root};
NVTX3_FUNC_WITH_PARAMS(Broadcast, BroadcastSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, root, 0, ncclSum, mscclFuncBroadcast, comm, stream);
@@ -200,7 +200,7 @@ ncclResult_t ncclGather(const void* sendbuff, void* recvbuff, size_t sendcount,
NvtxParamsGather payload{sendcount * ncclTypeSize(datatype), root};
NVTX3_FUNC_WITH_PARAMS(Gather, GatherSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
sendcount, datatype, root, 0, ncclSum, mscclFuncGather, comm, stream);
@@ -240,7 +240,7 @@ ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count,
NvtxParamsReduce payload{count * ncclTypeSize(datatype), root, op};
NVTX3_FUNC_WITH_PARAMS(Reduce, ReduceSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, root, 0, op, mscclFuncReduce, comm, stream);
@@ -268,7 +268,7 @@ ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recv
NvtxParamsReduceScatter payload{recvcount * ncclTypeSize(datatype), op};
NVTX3_FUNC_WITH_PARAMS(ReduceScatter, ReduceScatterSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
recvcount, datatype, 0, 0, op, mscclFuncReduceScatter, comm, stream);
@@ -295,7 +295,7 @@ ncclResult_t ncclScatter(const void* sendbuff, void* recvbuff, size_t recvcount,
NvtxParamsScatter payload{recvcount * ncclTypeSize(datatype), root};
NVTX3_FUNC_WITH_PARAMS(Scatter, ScatterSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, recvbuff, nullptr, nullptr,
recvcount, datatype, root, 0, ncclSum, mscclFuncScatter, comm, stream);
@@ -333,7 +333,7 @@ ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatyp
NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer};
NVTX3_FUNC_WITH_PARAMS(Send, SendRecvSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
sendbuff, nullptr, nullptr, nullptr, nullptr, nullptr,
count, datatype, 0, peer, ncclSum, mscclFuncSend, comm, stream);
@@ -356,7 +356,7 @@ ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int
NvtxParamsSendRecv payload{count * ncclTypeSize(datatype), peer};
NVTX3_FUNC_WITH_PARAMS(Recv, SendRecvSchema, payload)
if (mscclAvailable() && !mscclIsCaller()) {
if (mscclAvailable(comm->rank) && !mscclIsCaller()) {
return mscclEnqueueCheck(
nullptr, nullptr, nullptr, recvbuff, nullptr, nullptr,
count, datatype, 0, peer, ncclSum, mscclFuncRecv, comm, stream);
+2 -2
Féach ar an gComhad
@@ -17,7 +17,7 @@ void mscclSetIsCallerFlag();
void mscclClearIsCallerFlag();
bool mscclIsCaller();
bool mscclAvailable();
bool mscclAvailable(int rank = -1);
ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired);
@@ -33,7 +33,7 @@ ncclResult_t mscclEnqueueCheck(
ncclResult_t mscclGroupEnd();
ncclResult_t mscclTeardown();
ncclResult_t mscclTeardown(int rank);
size_t mscclKernMaxLocalSize();
+2 -2
Féach ar an gComhad
@@ -11,11 +11,11 @@
#include "comm.h"
#include "msccl/msccl_struct.h"
ncclResult_t mscclGetCaptureStatus(hipStream_t stream);
ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream);
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream);
ncclResult_t mscclSetupSyncFlags(hipStream_t stream);
ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream);
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm);
+8 -2
Féach ar an gComhad
@@ -8,9 +8,15 @@
#include "msccl/msccl_struct.h"
mscclStatus& mscclGetStatus();
bool mscclInitialized(int rank);
mscclSavedProxyArgs& mscclGetSavedProxyArgs();
void mscclSetInitialized(int rank, bool initialized = true);
void mscclRemoveRank(int rank);
mscclStatus& mscclGetStatus(int rank);
mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank);
mscclThreadLocalStatus& mscclGetThreadLocalStatus();
+2 -2
Féach ar an gComhad
@@ -1738,7 +1738,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
if (mscclEnabled() && (comm->topo->mscclEnabled || mscclForceEnabled())) {
NCCLCHECK(mscclInit(comm));
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
status.needsProxy |= mscclNeedsProxy;
}
@@ -2351,7 +2351,7 @@ static ncclResult_t commCleanup(ncclComm_t comm) {
#endif
if (mscclEnabled() && (mscclEnabledForTopo || mscclForceEnabled())) {
NCCLCHECK(mscclTeardown());
NCCLCHECK(mscclTeardown(comm->rank));
}
return ncclSuccess;
+44 -52
Féach ar an gComhad
@@ -25,8 +25,6 @@
RCCL_PARAM(MscclEnabled, "MSCCL_ENABLE", 1);
RCCL_PARAM(MscclForceEnabled, "MSCCL_FORCE_ENABLE", 0);
static const char* mscclAlgoFilePathEnv = "MSCCL_ALGO_FILE_PATH";
static std::atomic<bool> mscclInitialized;
static std::mutex mscclLifecycleMutex;
bool mscclEnabled() {
#ifdef COMPILE_MSCCL_KERNEL
@@ -56,23 +54,12 @@ bool mscclIsCaller() {
return mscclGetThreadLocalStatus().mscclIsCallerFlag;
}
bool mscclAvailable() {
return mscclEnabled() && mscclInitialized.load(std::memory_order_acquire);
bool mscclAvailable(int rank) {
return mscclEnabled() && mscclInitialized(rank);
}
static bool mscclCommCompatible(ncclComm_t comm) {
std::map<uint64_t, std::set<uint64_t>> hostHashToPidHashes;
for (int i = 0; i < comm->nRanks; i++) {
uint64_t hostHash = comm->peerInfo[i].hostHash;
uint64_t pidHash = comm->peerInfo[i].pidHash;
if (hostHashToPidHashes.find(hostHash) != hostHashToPidHashes.end()) {
auto& pidHashSet = hostHashToPidHashes[hostHash];
if (pidHashSet.find(pidHash) != pidHashSet.end()) {
return false;
}
}
hostHashToPidHashes[hostHash].insert(pidHash);
}
// MSCCL is always compatible now. No need to guard against multi-thread.
return true;
}
@@ -86,8 +73,8 @@ static const char* mscclAlgoShareDirPath = "../share/rccl/msccl-algorithms";
static const char* mscclUnitTestAlgoShareDirPath = "../share/rccl/msccl-unit-test-algorithms";
static ncclResult_t mscclInternalSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
static bool mscclAlgoMetaLoaded = false;
mscclStatus& status = mscclGetStatus();
static thread_local bool mscclAlgoMetaLoaded = false;
mscclStatus& status = mscclGetStatus(comm->rank);
*numChannelsRequired = 0;
// Query numChannelsRequired from loaded algorithm metas
@@ -166,9 +153,7 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
return ncclSuccess;
}
std::lock_guard<std::mutex> lock(mscclLifecycleMutex);
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
bool useInternalScheduler = false;
const char* mscclSchedulerPath = getenv(mscclSchedulerPathEnv);
@@ -199,20 +184,11 @@ ncclResult_t mscclSchedulerInit(ncclComm_t comm, int* numChannelsRequired) {
}
ncclResult_t mscclInit(ncclComm_t comm) {
// Always initialize thread local status
mscclThreadLocalStatus threadLocalStatus = mscclGetThreadLocalStatus();
threadLocalStatus.groupStatus = mscclNoGroup;
threadLocalStatus.groupDepth = 0;
threadLocalStatus.captureId = ULLONG_MAX;
threadLocalStatus.captureStatus = mscclNoCapture;
{
std::lock_guard<std::mutex> lock(mscclLifecycleMutex);
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
// freeAlgoHandles and needsProxy are initialized globally once and before algorithm pre-processing and connection
if (!mscclInitialized.load(std::memory_order_acquire)) {
if (!mscclInitialized(comm->rank)) {
status.freeAlgoHandles.resize(MSCCL_MAX_NUM_ALGOS);
for (int i = 0; i < MSCCL_MAX_NUM_ALGOS; i++) {
status.freeAlgoHandles[i] = MSCCL_MAX_NUM_ALGOS - i - 1;
@@ -229,6 +205,8 @@ ncclResult_t mscclInit(ncclComm_t comm) {
if (m.nRanks == comm->nRanks) {
// Load algorithms
if (status.rankToAlgoHandles[i].find(comm->rank) == status.rankToAlgoHandles[i].end()) {
static std::mutex loadAlgoMutex;
std::lock_guard<std::mutex> lock(loadAlgoMutex);
NCCLCHECK(mscclLoadAlgo(m.filePath.c_str(), &(status.rankToAlgoHandles[i][comm->rank]), comm->rank));
}
// Connect algorithms
@@ -241,7 +219,7 @@ ncclResult_t mscclInit(ncclComm_t comm) {
}
}
if (mscclInitialized.load(std::memory_order_acquire)) {
if (mscclInitialized(comm->rank)) {
return ncclSuccess;
}
@@ -250,7 +228,7 @@ ncclResult_t mscclInit(ncclComm_t comm) {
status.lastStream = nullptr;
NCCLCHECK(mscclInitWorkFifoStatus(&(status.defaultWorkFifoStatus)));
mscclInitialized.store(true, std::memory_order_release);
mscclSetInitialized(comm->rank);
}
INFO(NCCL_INIT, "MSCCL: Initialization finished, localSize %ld", mscclKernMaxLocalSize());
@@ -266,8 +244,8 @@ ncclResult_t mscclGroupStart() {
return ncclSuccess;
}
static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam* param) {
mscclStatus& status = mscclGetStatus();
static ncclResult_t mscclInternalSchedulerSelectAlgo(int rank, struct mscclSchedulerParam* param) {
mscclStatus& status = mscclGetStatus(rank);
param->scheduled = false;
// Current MSCCL doesn't support pre/post op
@@ -312,13 +290,13 @@ static ncclResult_t mscclInternalSchedulerSelectAlgo(struct mscclSchedulerParam*
}
static ncclResult_t mscclSchedulerSelectAlgo(struct mscclSavedSchedulerParam* param) {
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(param->comm->rank);
if (status.mscclSchedulerPtr) {
NCCLCHECK(status.mscclSchedulerPtr->selectAlgo(&(param->p)));
} else {
// Disable MSCCL algorithms if machine type is not matching
if (param->comm->topo->mscclEnabled || mscclForceEnabled()) {
NCCLCHECK(mscclInternalSchedulerSelectAlgo(&(param->p)));
NCCLCHECK(mscclInternalSchedulerSelectAlgo(param->comm->rank, &(param->p)));
} else {
param->p.scheduled = false;
}
@@ -513,12 +491,30 @@ ncclResult_t mscclGroupEnd() {
return ncclSuccess;
}
static ncclResult_t mscclInternalSchedulerTeardown() {
// Unloads one algorithm from the MSCCL status belonging to `rank`.
//
// The entry stored under `mscclAlgoHandle` in hostAlgos is released with
// free(), the entry in devAlgos with ncclCudaFree(); both map entries are
// erased, the handle is appended to freeAlgoHandles, and the handle is removed
// from every set held in connectedAlgos.
//
// Returns ncclSuccess on completion. The ncclCudaFree() call is wrapped in
// NCCLCHECK, so a failure there is handled by that macro before the remaining
// bookkeeping runs.
static ncclResult_t mscclInternalUnloadAlgo(int rank, mscclAlgoHandle_t mscclAlgoHandle) {
  mscclStatus& rankStatus = mscclGetStatus(rank);

  // Host-side table: free the stored pointer, then drop the entry.
  auto& hostTable = rankStatus.hostAlgos;
  free(hostTable[mscclAlgoHandle]);
  hostTable.erase(mscclAlgoHandle);

  // Device-side table: same sequence, using the NCCL allocator wrapper.
  auto& devTable = rankStatus.devAlgos;
  NCCLCHECK(ncclCudaFree(devTable[mscclAlgoHandle]));
  devTable.erase(mscclAlgoHandle);

  // Make the handle available for reuse.
  rankStatus.freeAlgoHandles.push_back(mscclAlgoHandle);

  // Forget the handle in every connected-algorithm set.
  for (auto entry = rankStatus.connectedAlgos.begin(); entry != rankStatus.connectedAlgos.end(); ++entry) {
    entry->second.erase(mscclAlgoHandle);
  }
  return ncclSuccess;
}
static ncclResult_t mscclInternalSchedulerTeardown(int rank) {
ncclResult_t ret = ncclSuccess, tmpRet = ncclSuccess;
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(rank);
for (auto &m : status.rankToAlgoHandles) {
for (auto &p : m) {
tmpRet = mscclUnloadAlgo(p.second);
tmpRet = mscclInternalUnloadAlgo(rank, p.second);
if (ret == ncclSuccess) {
ret = tmpRet;
}
@@ -529,18 +525,13 @@ static ncclResult_t mscclInternalSchedulerTeardown() {
return ret;
}
ncclResult_t mscclTeardown() {
// Always teardown thread local status
mscclThreadLocalStatus threadLocalStatus = mscclGetThreadLocalStatus();
threadLocalStatus.savedSchedulerParams.clear();
ncclResult_t mscclTeardown(int rank) {
{
std::lock_guard<std::mutex> lock(mscclLifecycleMutex);
if (!mscclInitialized.load(std::memory_order_acquire)) {
if (!mscclInitialized(rank)) {
mscclRemoveRank(rank);
return ncclSuccess;
}
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(rank);
for (auto &p : status.hostAlgos) {
free(p.second);
status.freeAlgoHandles.push_back(p.first);
@@ -563,13 +554,14 @@ ncclResult_t mscclTeardown() {
dlclose(status.mscclSchedulerLib);
status.mscclSchedulerLib = nullptr;
} else {
NCCLCHECK(mscclInternalSchedulerTeardown());
NCCLCHECK(mscclInternalSchedulerTeardown(rank));
}
NCCLCHECK(mscclDestroyWorkFifoStatus(&(status.defaultWorkFifoStatus)));
for (auto &p : status.graphWorkFifoStatus) {
NCCLCHECK(mscclDestroyWorkFifoStatus(&(p.second)));
}
mscclInitialized.store(false, std::memory_order_release);
mscclSetInitialized(rank, false);
mscclRemoveRank(rank);
}
INFO(NCCL_INIT, "MSCCL: Teardown finished");
+13 -14
Féach ar an gComhad
@@ -22,10 +22,10 @@ static inline size_t computeSizeNeeded(size_t nBytes, int nScratchChunks, int nC
return (nBytes * (size_t)nScratchChunks) / (size_t)nChunksPerLoop;
}
ncclResult_t mscclGetCaptureStatus(hipStream_t stream) {
mscclStatus& status = mscclGetStatus();
ncclResult_t mscclGetCaptureStatus(int rank, hipStream_t stream) {
mscclStatus& status = mscclGetStatus(rank);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs();
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(rank);
cudaStreamCaptureStatus captureStatus;
unsigned long long captureId;
CUDACHECK(hipStreamGetCaptureInfo_v2(stream, &captureStatus, &captureId, &threadLocalStatus.graph, nullptr, nullptr));
@@ -42,12 +42,12 @@ ncclResult_t mscclGetCaptureStatus(hipStream_t stream) {
} else {
threadLocalStatus.captureStatus = mscclNoCapture;
}
INFO(NCCL_NET,"mscclGetCaptureStatus: %d, captureId: %llu, size: %lu\n", threadLocalStatus.captureStatus, threadLocalStatus.captureId, mscclGetSavedProxyArgs()[captureId].size());
INFO(NCCL_NET,"mscclGetCaptureStatus: %d, captureId: %llu, size: %lu\n", threadLocalStatus.captureStatus, threadLocalStatus.captureId, savedProxyArgs[captureId].size());
return ncclSuccess;
}
ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t count, ncclDataType_t dataType) {
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
status.stepSize = comm->buffSizes[hostAlgo->protocol] / NCCL_STEPS;
status.chunkSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->chunkSteps : 1;
status.sliceSteps = hostAlgo->protocol == NCCL_PROTO_SIMPLE ? hostAlgo->sliceSteps : 1;
@@ -69,12 +69,11 @@ ncclResult_t mscclSetupCount(struct mscclAlgo* hostAlgo, ncclComm_t comm, size_t
}
ncclResult_t mscclSetupScratch(struct mscclAlgo* hostAlgo, hipStream_t stream) {
mscclStatus& status = mscclGetStatus();
return ncclSuccess;
}
ncclResult_t mscclSetupSyncFlags(hipStream_t stream) {
mscclStatus& status = mscclGetStatus();
ncclResult_t mscclSetupSyncFlags(int rank, hipStream_t stream) {
mscclStatus& status = mscclGetStatus(rank);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
if (threadLocalStatus.captureStatus == mscclNewCapture ||
status.workIndex > (1ULL << (8*sizeof(status.workIndex))) - 2 * NCCL_MAX_OPS - 1) {
@@ -86,7 +85,7 @@ ncclResult_t mscclSetupSyncFlags(hipStream_t stream) {
}
ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
// Check whether there are enough channels
if (hostAlgo->nChannels > comm->nChannels) {
@@ -124,7 +123,7 @@ ncclResult_t mscclSetupConnections(struct mscclAlgo* hostAlgo, ncclComm_t comm)
}
static ncclResult_t mscclSetupProxyImpl(struct mscclAlgo* hostAlgo, ncclComm_t comm) {
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
struct ncclProxyOp proxyOp = {};
proxyOp.connIndex = 0;
@@ -185,9 +184,9 @@ static void HIPRT_CB mscclSetupProxyCallback(void *args) {
}
ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStream_t stream) {
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs();
mscclSavedProxyArgs& savedProxyArgs = mscclGetSavedProxyArgs(comm->rank);
if (threadLocalStatus.captureStatus == mscclNoCapture) {
INFO(NCCL_NET,"mscclSetupProxy: no capture\n");
NCCLCHECK(mscclSetupProxyImpl(hostAlgo, comm));
@@ -203,7 +202,7 @@ ncclResult_t mscclSetupProxy(struct mscclAlgo* hostAlgo, ncclComm_t comm, hipStr
p.userData = params;
CUDACHECK(hipGraphAddHostNode(&callbackNode, threadLocalStatus.graph, nullptr, 0, &p));
}
mscclGetSavedProxyArgs()[threadLocalStatus.captureId].emplace_back(hostAlgo, comm);
savedProxyArgs[threadLocalStatus.captureId].emplace_back(hostAlgo, comm);
}
return ncclSuccess;
}
@@ -410,7 +409,7 @@ RCCL_PARAM(MscclForceFullOps, "MSCCL_FORCE_FULLOPS", 0);
ncclResult_t mscclSetupKernel(const void* sendBuff, void* recvBuff, size_t count,
ncclDataType_t dataType, ncclRedOp_t op, struct mscclAlgo* hostAlgo, struct mscclAlgo* devAlgo,
ncclComm_t comm, hipStream_t stream) {
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
mscclThreadLocalStatus& threadLocalStatus = mscclGetThreadLocalStatus();
if (status.lastStream != stream && status.lastStream != nullptr) {
+56 -6
Féach ar an gComhad
@@ -6,9 +6,60 @@
#include "msccl/msccl_status.h"
#include "msccl/msccl_struct.h"
mscclStatus& mscclGetStatus() {
static mscclStatus status;
return status;
#include "debug.h"
#include <memory>
#include <mutex>
#include <unordered_map>
using namespace std;
// Bundle of all MSCCL state kept for a single rank: the rank number, whether
// MSCCL has been marked initialized for it, its mscclStatus, and its saved
// proxy arguments.
struct mscclRankState {
// Rank this state belongs to; -1 until a rank is assigned.
int rank;
// Set/cleared through mscclSetInitialized(); starts out false.
bool initialized;
// MSCCL status for this rank.
mscclStatus status;
// Saved proxy arguments for this rank.
mscclSavedProxyArgs savedProxyArgs;
// Default state: no rank (-1), not initialized, value-initialized members.
mscclRankState() : rank(-1), initialized(false), status(), savedProxyArgs() {}
// Copying is allowed only through direct initialization (explicit), so a
// state object is never copied implicitly.
explicit mscclRankState(const mscclRankState&) = default;
};
static mutex rankStatesMutex;
static unordered_map<int, shared_ptr<mscclRankState>> rankStates;
// Returns the state object for `rank`.
//
// A negative rank yields the calling thread's own thread_local state. For a
// non-negative rank the shared rankStates map is consulted while holding
// rankStatesMutex; when no entry exists yet, one is created as a copy of the
// calling thread's thread_local state with its `rank` field set to `rank`.
//
// The mutex is only held for the lookup/creation: the returned reference is
// used by callers after the lock has been released.
static inline mscclRankState& mscclGetRankState(int rank) {
  static thread_local std::shared_ptr<mscclRankState> threadRankState = std::make_shared<mscclRankState>();

  if (rank >= 0) {
    std::lock_guard<std::mutex> guard(rankStatesMutex);
    auto slot = rankStates.find(rank);
    if (slot == rankStates.end()) {
      // First request for this rank: seed it from the thread-local state.
      std::shared_ptr<mscclRankState> fresh = std::make_shared<mscclRankState>(*threadRankState);
      fresh->rank = rank;
      slot = rankStates.emplace(rank, fresh).first;
    }
    return *(slot->second);
  }

  return *threadRankState;
}
bool mscclInitialized(int rank) {
return mscclGetRankState(rank).initialized;
}
void mscclSetInitialized(int rank, bool initialized) {
auto& state = mscclGetRankState(rank);
assert(!initialized || !state.initialized);
state.initialized = initialized;
}
// Drops the entry for `rank` from the shared rankStates map while holding
// rankStatesMutex. Does nothing when the rank has no entry.
void mscclRemoveRank(int rank) {
  std::lock_guard<std::mutex> guard(rankStatesMutex);
  auto slot = rankStates.find(rank);
  if (slot != rankStates.end()) {
    rankStates.erase(slot);
  }
}
// Returns the mscclStatus member of the state for `rank`
// (the state is obtained through mscclGetRankState()).
mscclStatus& mscclGetStatus(int rank) {
  mscclRankState& rankState = mscclGetRankState(rank);
  return rankState.status;
}
mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
@@ -16,7 +67,6 @@ mscclThreadLocalStatus& mscclGetThreadLocalStatus() {
return threadLocalStatus;
}
mscclSavedProxyArgs& mscclGetSavedProxyArgs() {
static mscclSavedProxyArgs savedProxyArgs;
return savedProxyArgs;
// Returns the savedProxyArgs member of the state for `rank`
// (the state is obtained through mscclGetRankState()).
mscclSavedProxyArgs& mscclGetSavedProxyArgs(int rank) {
  mscclRankState& rankState = mscclGetRankState(rank);
  return rankState.savedProxyArgs;
}
+5 -18
Féach ar an gComhad
@@ -12,7 +12,7 @@
NCCL_API(ncclResult_t, mscclLoadAlgo, const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank);
ncclResult_t mscclLoadAlgo(const char *mscclAlgoFilePath, mscclAlgoHandle_t *mscclAlgoHandle, int rank) {
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(rank);
if (status.freeAlgoHandles.size() == 0) {
WARN("MSCCL: MSCCL_MAX_NUM_ALGOS (%d) limit reached", MSCCL_MAX_NUM_ALGOS);
@@ -56,17 +56,17 @@ ncclResult_t mscclRunAlgo(
NvtxParamsMsccl payload{sendCounts[comm->rank] * ncclTypeSize(dataType), recvCounts[comm->rank] * ncclTypeSize(dataType)};
NVTX3_FUNC_WITH_PARAMS(MSCCL, MscclSchema, payload)
mscclStatus& status = mscclGetStatus();
mscclStatus& status = mscclGetStatus(comm->rank);
struct mscclAlgo* hostAlgo = status.hostAlgos[mscclAlgoHandle];
struct mscclAlgo* devAlgo = status.devAlgos[mscclAlgoHandle];
NCCLCHECK(mscclGetCaptureStatus(stream));
NCCLCHECK(mscclGetCaptureStatus(comm->rank, stream));
NCCLCHECK(mscclSetupCount(hostAlgo, comm, count, dataType));
NCCLCHECK(mscclSetupScratch(hostAlgo, stream));
NCCLCHECK(mscclSetupSyncFlags(stream));
NCCLCHECK(mscclSetupSyncFlags(comm->rank, stream));
NCCLCHECK(mscclSetupProxy(hostAlgo, comm, stream));
@@ -77,19 +77,6 @@ ncclResult_t mscclRunAlgo(
NCCL_API(ncclResult_t, mscclUnloadAlgo, mscclAlgoHandle_t mscclAlgoHandle);
ncclResult_t mscclUnloadAlgo(mscclAlgoHandle_t mscclAlgoHandle) {
mscclStatus& status = mscclGetStatus();
free(status.hostAlgos[mscclAlgoHandle]);
status.hostAlgos.erase(mscclAlgoHandle);
NCCLCHECK(ncclCudaFree(status.devAlgos[mscclAlgoHandle]));
status.devAlgos.erase(mscclAlgoHandle);
status.freeAlgoHandles.push_back(mscclAlgoHandle);
for (auto &s : status.connectedAlgos) {
s.second.erase(mscclAlgoHandle);
}
// deprecated
return ncclSuccess;
}
+1
Féach ar an gComhad
@@ -766,6 +766,7 @@ ncclResult_t pmscclRunAlgo(
/*! @endcond */
/*! @brief MSCCL Unload Algorithm
@deprecated This function has been removed from the public API.
@details Unload MSCCL algorithm previous loaded using its handle. This API
is expected to be called by MSCCL scheduler instead of end users.
@return Result code. See @ref rccl_result_code for more details.
+17
Féach ar an gComhad
@@ -5,6 +5,8 @@ cmake_minimum_required(VERSION 2.8.12)
if(BUILD_TESTS)
option(OPENMP_TESTS_ENABLED "Enable OpenMP for unit tests" OFF)
message("Building rccl unit tests (Installed in /test/rccl-UnitTests)")
find_package(hsa-runtime64 PATHS /opt/rocm )
@@ -23,6 +25,10 @@ if(BUILD_TESTS)
find_library(ROCR_LIB ${CORE_RUNTIME_TARGET} PATHS ${ROCR_LIB_DIR} "/opt/rocm" PATH_SUFFIXES lib lib64 REQUIRED)
endif()
if(OPENMP_TESTS_ENABLED)
find_package(OpenMP REQUIRED)
endif()
include_directories(${GTEST_INCLUDE_DIRS} ./common)
# Collect testing framework source files
@@ -60,12 +66,23 @@ if(BUILD_TESTS)
if(LL128_ENABLED)
target_compile_definitions(rccl-UnitTests PRIVATE ENABLE_LL128)
endif()
if(OPENMP_TESTS_ENABLED)
target_compile_definitions(rccl-UnitTests PRIVATE ENABLE_OPENMP)
endif()
target_compile_definitions(rccl-UnitTests PRIVATE ROCM_PATH="${ROCM_PATH}")
## Set rccl-UnitTests compile definitions
if(OPENMP_TESTS_ENABLED)
target_compile_options(rccl-UnitTests PRIVATE "${OpenMP_CXX_FLAGS}")
endif()
## Set rccl-UnitTests linked libraries
target_link_libraries(rccl-UnitTests PRIVATE ${GTEST_BOTH_LIBRARIES})
target_link_libraries(rccl-UnitTests PRIVATE hip::host hip::device hsa-runtime64::hsa-runtime64)
target_link_libraries(rccl-UnitTests PRIVATE Threads::Threads)
if(OPENMP_TESTS_ENABLED)
target_link_libraries(rccl-UnitTests PRIVATE "${OpenMP_CXX_FLAGS}")
endif()
# rccl-UnitTests using static library of rccl requires passing rccl
# through -l and -L instead of command line input.
+3 -1
Féach ar an gComhad
@@ -108,6 +108,7 @@ namespace RcclUnitTesting
showTiming = GetEnvVar("UT_SHOW_TIMING", 1);
useInteractive = GetEnvVar("UT_INTERACTIVE", 0);
timeoutUs = GetEnvVar("UT_TIMEOUT_US" , 5000000);
useMultithreading = GetEnvVar("UT_MULTITHREAD", false);
// Total number of reduction ops
int numOps = ncclNumOps;
@@ -232,7 +233,8 @@ namespace RcclUnitTesting
std::make_tuple("UT_PRINT_VALUES" , printValues , "Print array values (-1 for all)"),
std::make_tuple("UT_SHOW_TIMING" , showTiming , "Show timing table"),
std::make_tuple("UT_INTERACTIVE" , useInteractive, "Run in interactive mode"),
std::make_tuple("UT_TIMEOUT_US" , timeoutUs , "Timeout limit for collective calls in us")
std::make_tuple("UT_TIMEOUT_US" , timeoutUs , "Timeout limit for collective calls in us"),
std::make_tuple("UT_MULTITHREAD" , useMultithreading, "Multi-thread single-process ranks"),
};
printf("================================================================================\n");
+1
Féach ar an gComhad
@@ -29,6 +29,7 @@ namespace RcclUnitTesting
bool showTiming; // Show timing per case at end [UT_SHOW_TIMING]
bool useInteractive; // Run in interactive mode [UT_INTERACTIVE]
int timeoutUs; // Set timeout for child in microseconds [UT_TIMEOUT_US]
bool useMultithreading; // Multi-thread single-process ranks [UT_MULTITHREAD]
bool isGfx94; // Detects if architecture is gfx94
// Constructor that parses and collects environment variables
+25 -9
Féach ar an gComhad
@@ -8,7 +8,7 @@
namespace RcclUnitTesting
{
typedef enum
typedef enum : int
{
TEST_SUCCESS = 0,
TEST_FAIL = 1,
@@ -17,25 +17,41 @@ namespace RcclUnitTesting
#define ERROR(...) printf("\033[0;31m" "[ ERROR ] " "\033[0m" __VA_ARGS__)
#define INFO(...) printf("[ INFO ] " __VA_ARGS__)
#define WARN(...) printf("[ WARNING ] " __VA_ARGS__)
#define RETURN_RESULT(result) return (result)
#define CHECK_CALL(func) \
{ \
#define CHECK_CALL_BASE(func, RESULT, RESULT_ARGS...) \
do { \
ErrCode status = func; \
if (status != TEST_SUCCESS) \
{ \
ERROR("Error in call %s\n", #func); \
return status; \
RESULT(status, ##RESULT_ARGS); \
} \
}
} while (false)
#define CHECK_CALL(func) CHECK_CALL_BASE(func, RETURN_RESULT)
#define CHECK_HIP(func) \
{ \
#define CHECK_HIP_BASE(func, RESULT, RESULT_ARGS...) \
do { \
hipError_t error = (func); \
if (error != hipSuccess) \
{ \
fprintf(stderr, "\033[0;31m" "[ ERROR ] HIP error: %s File:%s Line:%d\n" "\033[m", \
hipGetErrorString(error), strrchr("/" __FILE__, '/') + 1, __LINE__); \
return TEST_FAIL; \
RESULT(TEST_FAIL, ##RESULT_ARGS); \
} \
}
} while (false)
#define CHECK_HIP(func) CHECK_HIP_BASE(func, RETURN_RESULT)
#ifdef ENABLE_OPENMP
#define OMP_CANCEL_FOR(result, errCode) errCode = (result); _Pragma("omp cancel for")
#define RANK_RESULT(errCode, result) OMP_CANCEL_FOR(result, errCode)
#define CHECK_CALL_RANK(errCode, func) CHECK_CALL_BASE(func, OMP_CANCEL_FOR, errCode)
#define CHECK_HIP_RANK(errCode, func) CHECK_HIP_BASE(func, OMP_CANCEL_FOR, errCode)
#else
#define RANK_RESULT(errCode, result) RETURN_RESULT(result)
#define CHECK_CALL_RANK(errCode, func) CHECK_CALL(func)
#define CHECK_HIP_RANK(errCode, func) CHECK_HIP(func)
#endif
}
+1 -1
Féach ar an gComhad
@@ -97,7 +97,7 @@ namespace RcclUnitTesting
childList.resize(this->numActiveChildren);
for (int childId = 0; childId < this->numActiveChildren; ++childId)
{
childList[childId] = new TestBedChild(childId, ev.verbose, ev.printValues);
childList[childId] = new TestBedChild(childId, ev.verbose, ev.printValues, ev.useMultithreading);
if (childList[childId]->InitPipes() != TEST_SUCCESS)
{
ERROR("Unable to create pipes to child process\n");
+76 -25
Féach ar an gComhad
@@ -8,20 +8,33 @@
#include <thread>
#include <execinfo.h>
#ifdef ENABLE_OPENMP
#include <omp.h>
#endif
#define CHILD_NCCL_CALL(cmd, msg) \
{ \
// Reports the OpenMP thread number of the calling thread.
// Builds that do not define ENABLE_OPENMP have no thread number to report and yield -1.
static int getThreadId()
{
#ifndef ENABLE_OPENMP
  return -1;
#else
  int const threadNum = (int)omp_get_thread_num();
  return threadNum;
#endif
}
#define CHILD_NCCL_CALL_BASE(cmd, msg, RESULT, RESULT_ARGS...) \
do { \
if (this->verbose) printf("[ NCCL CALL] " #cmd "\n"); \
ncclResult_t status = cmd; \
if (status != ncclSuccess) \
{ \
ERROR("Child process %d fails NCCL call %s with code %d\n", this->childId, msg, status); \
return TEST_FAIL; \
RESULT(TEST_FAIL, ##RESULT_ARGS); \
} \
}
} while (false)
#define CHILD_NCCL_CALL(cmd, msg) CHILD_NCCL_CALL_BASE(cmd, msg, RETURN_RESULT)
#define CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank) \
{ \
#define CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, RESULT, RESULT_ARGS...) \
do { \
unsigned long int loop_counter = 0; \
ncclResult_t ncclAsyncErr; \
loop_counter = 0; \
@@ -34,20 +47,30 @@
if (ncclAsyncErr != ncclSuccess) \
{ \
ERROR("Child process %d fails NCCL call %s with code %d\n", this->childId, msg, ncclAsyncErr); \
return TEST_FAIL; \
RESULT(TEST_FAIL, ##RESULT_ARGS); \
} \
}
} while (false)
#define CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, RETURN_RESULT)
#define PIPE_READ(val) \
if (read(childReadFd, &val, sizeof(val)) != sizeof(val)) return TEST_FAIL;
// Per-rank variants of the NCCL call macros. Each takes an extra errCode argument.
// - With ENABLE_OPENMP: they forward to the *_BASE macros with OMP_CANCEL_FOR as the result
//   handler and errCode as its extra argument, so a failure is handed to
//   OMP_CANCEL_FOR(TEST_FAIL, errCode) instead of returning TEST_FAIL.
// - Without ENABLE_OPENMP: errCode is ignored and the plain returning macros are used.
// NOTE(review): OMP_CANCEL_FOR is not defined in this file - presumably it comes from the
// shared error-code header; confirm it is included before these macros are expanded.
#ifdef ENABLE_OPENMP
#define CHILD_NCCL_CALL_RANK(errCode, cmd, msg) CHILD_NCCL_CALL_BASE(cmd, msg, OMP_CANCEL_FOR, errCode)
#define CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING_BASE(msg, localRank, OMP_CANCEL_FOR, errCode)
#else
// errCode is dropped here; failures return TEST_FAIL from the enclosing function
#define CHILD_NCCL_CALL_RANK(errCode, cmd, msg) CHILD_NCCL_CALL(cmd, msg)
#define CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, msg, localRank) CHILD_NCCL_CALL_NON_BLOCKING(msg, localRank)
#endif
namespace RcclUnitTesting
{
TestBedChild::TestBedChild(int const childId, bool const verbose, int const printValues)
// Records the child's identifier and its run options; performs no other setup.
//   childId          - identifier stored for this child
//   verbose          - stored verbosity flag
//   printValues      - stored print-values setting
//   useRankThreading - stored flag selecting multi-threaded rank execution
TestBedChild::TestBedChild(int const childId, bool const verbose, int const printValues, bool const useRankThreading)
  : childId(childId),
    verbose(verbose),
    printValues(printValues),
    useRankThreading(useRankThreading)
{
}
int TestBedChild::InitPipes()
@@ -83,6 +106,9 @@ namespace RcclUnitTesting
// Wait for commands from parent process
if (verbose) INFO("Child %d enters execution loop\n", this->childId);
#ifndef ENABLE_OPENMP
if (verbose && useRankThreading) WARN("Multi-threaded ranks requires ENABLE_OPENMP to be defined\n");
#endif
int command;
while (read(childReadFd, &command, sizeof(command)) > 0)
{
@@ -473,6 +499,8 @@ namespace RcclUnitTesting
}
}
int numThreadsToUse = this->useRankThreading ? numRanksToExecute : 1;
// Start group call
CHILD_NCCL_CALL(ncclGroupStart(), "ncclGroupStart");
@@ -480,9 +508,17 @@ namespace RcclUnitTesting
for (int collId = 0; collId < this->numCollectivesInGroup[groupId]; ++collId)
{
// Loop over all local ranks
if (this->verbose && this->useRankThreading)
INFO("Group %d collective %d running %d threads\n", groupId, collId, numThreadsToUse);
ErrCode errCode = TEST_SUCCESS;
auto& errCodeVal = reinterpret_cast<int&>(errCode);
#pragma omp parallel for num_threads(numThreadsToUse) reduction(max : errCodeVal)
for (int localRank : localRanksToExecute)
{
CHECK_HIP(hipSetDevice(this->deviceIds[localRank]));
if (this->verbose && this->useRankThreading)
INFO("Group %d collective %d running rank %d on thread %d\n", groupId, collId, localRank, getThreadId());
CHECK_HIP_RANK(errCode, hipSetDevice(this->deviceIds[localRank]));
CollectiveArgs const& collArg = this->collArgs[groupId][localRank][collId];
@@ -492,14 +528,14 @@ namespace RcclUnitTesting
PtrUnion inputCpu;
size_t const numInputBytes = numInputElementsToPrint * DataTypeToBytes(collArg.dataType);
inputCpu.AllocateCpuMem(numInputBytes);
CHECK_HIP(hipMemcpy(inputCpu.ptr, collArg.inputGpu.ptr, numInputBytes, hipMemcpyDeviceToHost));
CHECK_HIP_RANK(errCode, hipMemcpy(inputCpu.ptr, collArg.inputGpu.ptr, numInputBytes, hipMemcpyDeviceToHost));
printf("[ DEBUG ] Rank %02d Group %d Coll %d %-10s: %s\n", collArg.globalRank, groupId, collId, "Input",
inputCpu.ToString(collArg.dataType, numInputElementsToPrint).c_str());
inputCpu.FreeCpuMem();
int const numOutputElementsToPrint = (this->printValues < 0 ? collArg.numOutputElements : this->printValues);
size_t const numOutputBytes = numOutputElementsToPrint * DataTypeToBytes(collArg.dataType);
CHECK_HIP(hipMemcpy(collArg.outputCpu.ptr, collArg.outputGpu.ptr, numOutputBytes, hipMemcpyDeviceToHost));
CHECK_HIP_RANK(errCode, hipMemcpy(collArg.outputCpu.ptr, collArg.outputGpu.ptr, numOutputBytes, hipMemcpyDeviceToHost));
printf("[ DEBUG ] Rank %02d Group %d Coll %d %-10s: %s\n", collArg.globalRank, groupId, collId, "Pre-Output",
collArg.outputCpu.ToString(collArg.dataType, numOutputElementsToPrint).c_str());
}
@@ -507,7 +543,8 @@ namespace RcclUnitTesting
switch (collArg.funcType)
{
case ncclCollBroadcast:
CHILD_NCCL_CALL(ncclBroadcast(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclBroadcast(
collArg.inputGpu.ptr,
collArg.outputGpu.ptr,
collArg.numInputElements,
collArg.dataType,
@@ -517,7 +554,8 @@ namespace RcclUnitTesting
"ncclBroadcast");
break;
case ncclCollReduce:
CHILD_NCCL_CALL(ncclReduce(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclReduce(
collArg.inputGpu.ptr,
collArg.outputGpu.ptr,
collArg.numInputElements,
collArg.dataType,
@@ -528,7 +566,8 @@ namespace RcclUnitTesting
"ncclReduce");
break;
case ncclCollAllGather:
CHILD_NCCL_CALL(ncclAllGather(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclAllGather(
collArg.inputGpu.ptr,
collArg.outputGpu.ptr,
collArg.numInputElements,
collArg.dataType,
@@ -537,7 +576,8 @@ namespace RcclUnitTesting
"ncclAllGather");
break;
case ncclCollReduceScatter:
CHILD_NCCL_CALL(ncclReduceScatter(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclReduceScatter(
collArg.inputGpu.ptr,
collArg.outputGpu.ptr,
collArg.numOutputElements,
collArg.dataType,
@@ -547,7 +587,8 @@ namespace RcclUnitTesting
"ncclReduceScatter");
break;
case ncclCollAllReduce:
CHILD_NCCL_CALL(ncclAllReduce(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclAllReduce(
collArg.inputGpu.ptr,
collArg.outputGpu.ptr,
collArg.numInputElements,
collArg.dataType,
@@ -557,7 +598,8 @@ namespace RcclUnitTesting
"ncclAllReduce");
break;
case ncclCollGather:
CHILD_NCCL_CALL(ncclGather(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclGather(
collArg.inputGpu.ptr,
collArg.outputGpu.ptr,
collArg.numInputElements,
collArg.dataType,
@@ -567,7 +609,8 @@ namespace RcclUnitTesting
"ncclGather");
break;
case ncclCollScatter:
CHILD_NCCL_CALL(ncclScatter(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclScatter(
collArg.inputGpu.ptr,
collArg.outputGpu.ptr,
collArg.numOutputElements,
collArg.dataType,
@@ -577,7 +620,8 @@ namespace RcclUnitTesting
"ncclScatter");
break;
case ncclCollAllToAll:
CHILD_NCCL_CALL(ncclAllToAll(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclAllToAll(
collArg.inputGpu.ptr,
collArg.outputGpu.ptr,
collArg.numInputElements / collArg.totalRanks,
collArg.dataType,
@@ -586,7 +630,8 @@ namespace RcclUnitTesting
"ncclAllToAll");
break;
case ncclCollAllToAllv:
CHILD_NCCL_CALL(ncclAllToAllv(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclAllToAllv(
collArg.inputGpu.ptr,
collArg.options.sendcounts + (this->rankOffset + localRank)*this->totalRanks,
collArg.options.sdispls + (this->rankOffset + localRank)*this->totalRanks,
collArg.outputGpu.ptr,
@@ -598,7 +643,8 @@ namespace RcclUnitTesting
"ncclAllToAllv");
break;
case ncclCollSend:
CHILD_NCCL_CALL(ncclSend(collArg.inputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclSend(
collArg.inputGpu.ptr,
collArg.numInputElements,
collArg.dataType,
collArg.options.root,
@@ -607,7 +653,8 @@ namespace RcclUnitTesting
"ncclSend");
break;
case ncclCollRecv:
CHILD_NCCL_CALL(ncclRecv(collArg.outputGpu.ptr,
CHILD_NCCL_CALL_RANK(errCode, ncclRecv(
collArg.outputGpu.ptr,
collArg.numOutputElements,
collArg.dataType,
collArg.options.root,
@@ -617,14 +664,18 @@ namespace RcclUnitTesting
break;
default:
ERROR("Unknown func type %d\n", collArg.funcType);
return TEST_FAIL;
RANK_RESULT(errCode, TEST_FAIL);
}
if (this->useBlocking == false)
{
CHILD_NCCL_CALL_NON_BLOCKING("ncclCommGetAsyncErrorExecuteCollectives", localRank);
CHILD_NCCL_CALL_NON_BLOCKING_RANK(errCode, "ncclCommGetAsyncErrorExecuteCollectives", localRank);
}
if (this->verbose && this->useRankThreading)
INFO("Group %d collective %d done rank %d on thread %d\n", groupId, collId, localRank, getThreadId());
}
if (this->useRankThreading) CHECK_CALL(errCode);
}
// End group call
if (this->useBlocking == false)
+2 -1
Féach ar an gComhad
@@ -57,6 +57,7 @@ namespace RcclUnitTesting
pid_t pid;
bool verbose;
int printValues;
bool useRankThreading;
// Pipes used to communicate between parent process
int parentWriteFd;
@@ -80,7 +81,7 @@ namespace RcclUnitTesting
std::vector<std::vector<std::vector<bool>>> graphEnabled;
// Constructor
TestBedChild(int const childId, bool const verbose, int const printValues);
TestBedChild(int const childId, bool const verbose, int const printValues, bool const useRankThreading);
// Prepare parent/child communication pipes - to be executed by parent process
int InitPipes();