From cf3ffb2f5f4c4d0ccf4f8a050e34bafc2462cfa9 Mon Sep 17 00:00:00 2001 From: David Addison Date: Thu, 25 Jul 2024 21:47:40 -0700 Subject: [PATCH] Added -N,--run_cycles option [ROCm/rccl-tests commit: d2d40cc8249378efa4d7e2c949528c15eeb7d8e7] --- projects/rccl-tests/src/common.cu | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/projects/rccl-tests/src/common.cu b/projects/rccl-tests/src/common.cu index 04e81422f0..872a18a1b6 100644 --- a/projects/rccl-tests/src/common.cu +++ b/projects/rccl-tests/src/common.cu @@ -69,6 +69,7 @@ static int datacheck = 1; static int warmup_iters = 5; static int iters = 20; static int agg_iters = 1; +static int run_cycles = 1; static int ncclop = ncclSum; static int nccltype = ncclFloat; static int ncclroot = 0; @@ -598,7 +599,9 @@ testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char* TESTCHECK(completeColl(args)); // Benchmark - for (size_t size = args->minbytes; size<=args->maxbytes; size = ((args->stepfactor > 1) ? size*args->stepfactor : size+args->stepbytes)) { + long repeat = run_cycles; + do { + for (size_t size = args->minbytes; size<=args->maxbytes; size = ((args->stepfactor > 1) ? size*args->stepfactor : size+args->stepbytes)) { setupArgs(size, type, args); char rootName[100]; sprintf(rootName, "%6i", root); @@ -606,7 +609,9 @@ testResult_t TimeTest(struct threadArgs* args, ncclDataType_t type, const char* TESTCHECK(BenchTime(args, type, op, root, 0)); TESTCHECK(BenchTime(args, type, op, root, 1)); PRINT("\n"); - } + } + } while (--repeat); + return testSuccess; } @@ -717,6 +722,7 @@ int main(int argc, char* argv[]) { {"iters", required_argument, 0, 'n'}, {"agg_iters", required_argument, 0, 'm'}, {"warmup_iters", required_argument, 0, 'w'}, + {"run_cycles", required_argument, 0, 'N'}, {"parallel_init", required_argument, 0, 'p'}, {"check", required_argument, 0, 'c'}, {"op", required_argument, 0, 'o'}, @@ -735,7 +741,7 @@ int main(int argc, char* argv[]) { while(1) { int c; - c = getopt_long(argc, argv, "t:g:b:e:i:f:n:m:w:p:c:o:d:r:z:y:T:hG:C:a:R:", longopts, &longindex); + c = getopt_long(argc, argv, "t:g:b:e:i:f:n:m:w:N:p:c:o:d:r:z:y:T:hG:C:a:R:", longopts, &longindex); if (c == -1) break; @@ -782,6 +788,9 @@ int main(int argc, char* argv[]) { case 'w': warmup_iters = (int)strtol(optarg, NULL, 0); break; + case 'N': + run_cycles = (int)strtol(optarg, NULL, 0); + break; case 'c': datacheck = (int)strtol(optarg, NULL, 0); break; @@ -841,6 +850,7 @@ int main(int argc, char* argv[]) { "[-n,--iters ] \n\t" "[-m,--agg_iters ] \n\t" "[-w,--warmup_iters ] \n\t" + "[-N,--run_cycles run & print each cycle (default: 1; 0=infinite)] \n\t" "[-p,--parallel_init <0/1>] \n\t" "[-c,--check ] \n\t" #if NCCL_VERSION_CODE >= NCCL_VERSION(2,11,0)