diff --git a/src/collectives.cc b/src/collectives.cc index 50a7a9297f..390045eccc 100644 --- a/src/collectives.cc +++ b/src/collectives.cc @@ -131,16 +131,14 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen dstBuf = recvbuff; } - if (!in_place) - CUDACHECK(cudaMemcpyAsync((char*)dstBuf + rank * rankOffset, srcBuf, rankOffset, cudaMemcpyDeviceToDevice, stream)); - NCCLCHECK(ncclGroupStart()); for (int r = 0; r < nRanks; r++) { - if (r != rank) { - NCCLCHECK(ncclSend(((char*)dstBuf) + rank * rankOffset, sendcount, datatype, r, comm, stream)); - NCCLCHECK(ncclRecv(((char*)dstBuf) + r * rankOffset, sendcount, datatype, r, comm, stream)); - } + if (r == rank && in_place) + continue; + + NCCLCHECK(ncclSend(((char*)srcBuf), sendcount, datatype, r, comm, stream)); + NCCLCHECK(ncclRecv(((char*)dstBuf) + r * rankOffset, sendcount, datatype, r, comm, stream)); } NCCLCHECK(ncclGroupEnd()); return ncclSuccess;