Wrap ncclCommWindowRegister() calls within ncclGroup
[ROCm/rccl-tests commit: e7c8825b0b]
This commit is contained in:
@@ -659,6 +659,7 @@ testResult_t threadInit(struct threadArgs* args) {
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
void **sendRegHandles = (local_register) ? (void **)malloc(sizeof(*sendRegHandles)*args->nGpus) : NULL;
|
||||
void **recvRegHandles = (local_register) ? (void **)malloc(sizeof(*recvRegHandles)*args->nGpus) : NULL;
|
||||
for (int i=0; i<args->nGpus; i++) {
|
||||
@@ -673,6 +674,7 @@ testResult_t threadInit(struct threadArgs* args) {
|
||||
if (local_register) NCCLCHECK(ncclCommRegister(args->comms[i], args->recvbuffs[i], args->maxbytes, &recvRegHandles[i]));
|
||||
}
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
#endif
|
||||
|
||||
TESTCHECK(threadRunTests(args));
|
||||
@@ -1124,6 +1126,7 @@ testResult_t run() {
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
}
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2,19,0)
|
||||
NCCLCHECK(ncclGroupStart());
|
||||
sendRegHandles = (local_register) ? (void **)malloc(sizeof(*sendRegHandles)*nThreads*nGpus) : NULL;
|
||||
recvRegHandles = (local_register) ? (void **)malloc(sizeof(*recvRegHandles)*nThreads*nGpus) : NULL;
|
||||
for (int i=0; i<nGpus*nThreads; i++) {
|
||||
@@ -1138,6 +1141,7 @@ testResult_t run() {
|
||||
if (local_register) NCCLCHECK(ncclCommRegister(comms[i], recvbuffs[i], maxBytes, &recvRegHandles[i]));
|
||||
}
|
||||
}
|
||||
NCCLCHECK(ncclGroupEnd());
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user