diff --git a/projects/rccl/src/rccl_wrap.cc b/projects/rccl/src/rccl_wrap.cc index caf8a79488..4f7ee90a06 100644 --- a/projects/rccl/src/rccl_wrap.cc +++ b/projects/rccl/src/rccl_wrap.cc @@ -474,26 +474,25 @@ void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKe void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes) { info->useWarpSpeed = false; - if(!comm->topo->warpSpeedEnabled) { - return; - } - info->useWarpSpeed = (info->algorithm == NCCL_ALGO_RING); // Enabled by default for any RING algorithm when platform supports it - if(rcclParamWarpSpeedAutoMode() != 0 && IsArchMatch(comm->archName, "gfx950")) { // Auto mode only available for gfx950 currently + if(rcclParamWarpSpeedAutoMode() != 0) { // auto mode + if(!IsArchMatch(comm->archName, "gfx950")) { + // Auto mode only available for gfx950 currently, keep it to false + return; + } size_t minBytes = 0; + commSetUnrollFactor(comm); // TODO: reset unroll factor per task rather than per comm if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather) minBytes = RCCL_WARP_SPEED_MIN_BYTES; else if (info->func == ncclFuncReduceScatter) minBytes = RCCL_WARP_SPEED_MIN_BYTES << 2; // ReduceScatter requires higher message size to benefit from WarpSpeed if(comm->nNodes == 1) { if(nBytes >= minBytes && minBytes > 0) { comm->unroll = NCCL_UNROLL_2; info->nWarps = 4; + info->useWarpSpeed = true; } - } else { - // TODO: set unroll factor per task rather than per comm - commSetUnrollFactor(comm); - info->useWarpSpeed = false; } + } else if (comm->topo->warpSpeedEnabled && info->algorithm == NCCL_ALGO_RING) { + info->useWarpSpeed = true; } - } #endif