Add CUDA graph support only for CUDA 11.3 and later builds
Fixes #90
[ROCm/rccl-tests commit: 1f8f541686]
This commit is contained in:
@@ -534,6 +534,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
|
||||
Barrier(args);
|
||||
|
||||
#if CUDART_VERSION >= 11030
|
||||
cudaGraph_t graphs[args->nGpus];
|
||||
cudaGraphExec_t graphExec[args->nGpus];
|
||||
if (cudaGraphLaunches >= 1) {
|
||||
@@ -542,6 +543,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
CUDACHECK(cudaStreamBeginCapture(args->streams[i], args->nThreads > 1 ? cudaStreamCaptureModeThreadLocal : cudaStreamCaptureModeGlobal));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Performance Benchmark
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
@@ -553,6 +555,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
if (agg_iters>1) NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
|
||||
#if CUDART_VERSION >= 11030
|
||||
if (cudaGraphLaunches >= 1) {
|
||||
// End cuda graph capture
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
@@ -571,6 +574,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TESTCHECK(completeColl(args));
|
||||
|
||||
@@ -580,6 +584,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
if (cudaGraphLaunches >= 1) deltaSec = deltaSec/cudaGraphLaunches;
|
||||
Allreduce(args, &deltaSec, average);
|
||||
|
||||
#if CUDART_VERSION >= 11030
|
||||
if (cudaGraphLaunches >= 1) {
|
||||
//destroy cuda graph
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
@@ -587,6 +592,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
CUDACHECK(cudaGraphDestroy(graphs[i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
double algBw, busBw;
|
||||
args->collTest->getBw(count, wordSize(type), deltaSec, &algBw, &busBw, args->nProcs*args->nThreads*args->nGpus);
|
||||
@@ -600,16 +606,19 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
// Initialize sendbuffs, recvbuffs and expected
|
||||
TESTCHECK(args->collTest->initData(args, type, op, root, rep, in_place));
|
||||
|
||||
#if CUDART_VERSION >= 11030
|
||||
if (cudaGraphLaunches >= 1) {
|
||||
// Begin cuda graph capture for data check
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
CUDACHECK(cudaStreamBeginCapture(args->streams[i], cudaStreamCaptureModeThreadLocal));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// test validation in a single iteration; should ideally be included in the multi-iteration run
|
||||
TESTCHECK(startColl(args, type, op, root, in_place, 0));
|
||||
|
||||
#if CUDART_VERSION >= 11030
|
||||
if (cudaGraphLaunches >= 1) {
|
||||
// End cuda graph capture
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
@@ -624,9 +633,11 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
CUDACHECK(cudaGraphLaunch(graphExec[i], args->streams[i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TESTCHECK(completeColl(args));
|
||||
|
||||
#if CUDART_VERSION >= 11030
|
||||
if (cudaGraphLaunches >= 1) {
|
||||
//destroy cuda graph
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
@@ -634,6 +645,7 @@ testResult_t BenchTime(struct threadArgs* args, ncclDataType_t type, ncclRedOp_t
|
||||
CUDACHECK(cudaGraphDestroy(graphs[i]));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TESTCHECK(CheckData(args, type, op, root, in_place, &maxDelta));
|
||||
|
||||
|
||||
Reference in a new issue
Block a user