diff --git a/src/collectives.cc b/src/collectives.cc index 120734280e..e07a6eed44 100644 --- a/src/collectives.cc +++ b/src/collectives.cc @@ -92,7 +92,6 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen ALLGATHER_CHUNKSTEPS, comm -> rcclUseOneSlice ? ALLGATHER_SLICESTEPS_SINGLE_NODE : ALLGATHER_SLICESTEPS, nullptr }; int nRanks; - const void* srcbuff; int in_place = 0; NCCLCHECK(ncclCommCount(comm, &nRanks)); size_t msgSize = sendcount * ncclTypeSize(datatype) * nRanks; @@ -112,19 +111,17 @@ ncclResult_t ncclAllGather_impl(const void* sendbuff, void* recvbuff, size_t sen // use direct allgather if (sendcount == 0) return ncclSuccess; size_t rankOffset = sendcount * ncclTypeSize(datatype); - if (((char*)recvbuff) != (((char*)sendbuff) + comm->rank * rankOffset)) { - srcbuff = sendbuff; - } else { - srcbuff = ((char*)recvbuff) + comm->rank * rankOffset; + if (((char*)sendbuff) == (((char*)recvbuff) + comm->rank * rankOffset)) { in_place = 1; - } + } + NCCLCHECK(ncclGroupStart()); for (int r = 0; r < nRanks; r++) { int peer = (comm->rank + r) % nRanks; if (in_place && (peer == comm->rank)) { continue; } - NCCLCHECK(ncclSend(((char*)srcbuff), sendcount, datatype, peer, comm, stream)); + NCCLCHECK(ncclSend(sendbuff, sendcount, datatype, peer, comm, stream)); NCCLCHECK(ncclRecv(((char*)recvbuff) + peer * rankOffset, sendcount, datatype, peer, comm, stream)); } NCCLCHECK(ncclGroupEnd()); diff --git a/src/rccl_wrap.cc b/src/rccl_wrap.cc index 1a622fdb71..2b1bf1f0a8 100644 --- a/src/rccl_wrap.cc +++ b/src/rccl_wrap.cc @@ -33,7 +33,7 @@ RCCL_PARAM(PipelineAllDTypes, "PIPELINE_ALL_DATA_TYPES", 0); // Use this to assess impact of pipelining on performance. // Otherwise, it is automatically set for certain archs, datatypes and reduction collectives RCCL_PARAM(disableReduceCopyPipelining, "DISABLE_REDUCE_COPY_PIPELINING", 0); -RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 4194304); +RCCL_PARAM(DirectAllGatherThreshold, "DIRECT_ALLGATHER_THRESHOLD", 75497472); void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) { // Honor user input for protocol choice @@ -42,8 +42,11 @@ void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, s const char *protoStr = getenv("NCCL_PROTO"); userProtocolInput = !protoStr ? 0 : 1; } + if (!userProtocolInput && IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") && comm->nNodes == 1 && (info->func == ncclFuncAllGather) && nBytes <= 524288) { + // Change LL protocol threshold + info->protocol = NCCL_PROTO_LL; - if(!userProtocolInput && comm->nNodes >= 2 && (info->func == ncclFuncReduceScatter || info->func == ncclFuncAllGather || info->func == ncclFuncAllReduce || info->func == ncclFuncBroadcast || info->func == ncclFuncReduce)) { + } else if(!userProtocolInput && comm->nNodes >= 2 && (info->func == ncclFuncReduceScatter || info->func == ncclFuncAllGather || info->func == ncclFuncAllReduce || info->func == ncclFuncBroadcast || info->func == ncclFuncReduce)) { auto tunableIndex = rcclGetTunableIndex(info->func); auto llMin = comm->minMaxLLRange[tunableIndex][NCCL_PROTO_LL][RCCL_PROTOCOL_MIN_IDX]; auto llMax = comm->minMaxLLRange[tunableIndex][NCCL_PROTO_LL][RCCL_PROTOCOL_MAX_IDX]; @@ -291,8 +294,25 @@ ncclResult_t rcclGetProtocolName(int protocol, const char** protocolName) { } bool rcclUseAllGatherDirect(struct ncclComm* comm, size_t& msgSize) { - return (comm->enableCustColl && (comm->nNodes > 1 && comm->nNodes <= 16) && (msgSize <= rcclParamDirectAllGatherThreshold() && - rcclParamDirectAllGatherThreshold() > -1)); + size_t threshold = rcclParamDirectAllGatherThreshold(); + + if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950")) { + if (comm->nNodes == 1 && threshold != -1) { + threshold = 8388608; + } else if (comm->nNodes < 64 && threshold != -1) { + threshold = comm->nNodes * 2097152; + } + } else if (IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942")) { + threshold = 4194304; + } + + comm->enableCustColl = IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx950") || IsArchMatch(comm->topo->nodes[GPU].nodes[0].gpu.gcn, "gfx942"); + + int rankMultiple = comm->nRanks % 8; + + //return (comm->enableCustColl && (comm->nNodes > 1) && (msgSize <= threshold) && (threshold != -1)) + return (comm->enableCustColl && (msgSize <= threshold) && (threshold != -1) && !rankMultiple) + ; } void rcclSetPxn(struct ncclComm* comm, int& rcclPxnDisable) { diff --git a/test/common/CollectiveArgs.cpp b/test/common/CollectiveArgs.cpp index 1235d708c8..28ce1b1c09 100644 --- a/test/common/CollectiveArgs.cpp +++ b/test/common/CollectiveArgs.cpp @@ -82,7 +82,7 @@ namespace RcclUnitTesting CHECK_CALL(this->inputGpu.AllocateGpuMem(this->numInputBytesAllocated, useManagedMem, userRegistered)); this->outputGpu.Attach(this->inputGpu.U1 + (this->globalRank * this->numOutputBytesAllocated)); } - else if (this->funcType == ncclCollGather) + else if (this->funcType == ncclCollGather || this->funcType == ncclCollAllGather) { CHECK_CALL(this->outputGpu.AllocateGpuMem(this->numOutputBytesAllocated, useManagedMem, userRegistered)); this->inputGpu.Attach(this->outputGpu.U1 + (this->globalRank * this->numInputBytesAllocated));