diff --git a/test/AllToAllv_OutOfPlace.cpp b/test/AllToAllv_OutOfPlace.cpp index 083f3353cd..f072558b30 100644 --- a/test/AllToAllv_OutOfPlace.cpp +++ b/test/AllToAllv_OutOfPlace.cpp @@ -11,13 +11,18 @@ namespace RcclUnitTesting void sendRecvPrep(size_t numInputElementsArray[], size_t numOutputElementsArray[], OptionalColArgs &options, - int totalRanks, int numElementsBase) + int totalRanks, int numElements) { + int sendcount = (numElements/totalRanks)*totalRanks; + size_t chunksize = sendcount*2/(totalRanks*totalRanks); for (int sendRank = 0; sendRank < totalRanks; ++sendRank) for (int recvRank = 0; recvRank < totalRanks; ++recvRank ) { //create send counts, and build other arrays from that - options.sendcounts[sendRank*totalRanks+recvRank] = numElementsBase * (recvRank + 1); + size_t scount = ((sendRank+recvRank)%totalRanks)*chunksize; + if ((sendRank+recvRank)%totalRanks == 0) + scount += (sendcount-chunksize*(totalRanks-1)*totalRanks/2); + options.sendcounts[sendRank*totalRanks+recvRank] = scount; options.recvcounts[recvRank*totalRanks+sendRank] = options.sendcounts[sendRank*totalRanks+recvRank ]; } @@ -43,7 +48,7 @@ namespace RcclUnitTesting TestBed testBed; // Configuration std::vector const& dataTypes = {ncclInt32, ncclFloat64}; - std::vector const numElementsBase = {1048576, 53327, 1024}; + std::vector const numElements = {1048576, 53327, 1024}; bool const inPlace = false; bool const useManagedMem = false; @@ -59,7 +64,7 @@ namespace RcclUnitTesting testBed.InitComms(TestBed::GetDeviceIdsList(numProcesses, totalRanks)); for (int dataIdx = 0; dataIdx < dataTypes.size() && isCorrect; ++dataIdx) - for (int numIdx = 0; numIdx < numElementsBase.size() && isCorrect; ++numIdx) + for (int numIdx = 0; numIdx < numElements.size() && isCorrect; ++numIdx) { if (testBed.ev.showNames) { @@ -70,7 +75,7 @@ namespace RcclUnitTesting INFO("%s\n", name.c_str()); } - sendRecvPrep(numInputElementsArray, numOutputElementsArray, options, totalRanks, numElementsBase[numIdx]); + sendRecvPrep(numInputElementsArray, numOutputElementsArray, options, totalRanks, numElements[numIdx]); for (int rank = 0; rank < totalRanks; ++rank) { testBed.SetCollectiveArgs(ncclCollAllToAllv,