Add fix for WarpSpeed auto mode (#2104)

[ROCm/rccl commit: 5787c960fc]
Bu işleme şunda yer alıyor:
Mustafa Abduljabbar
2025-12-12 17:56:52 -05:00
işlemeyi yapan: GitHub
ebeveyn 2621e0254e
işleme 88652b53d0
+9 -10
Dosyayı Görüntüle
@@ -474,26 +474,25 @@ void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKe
// Decide whether the collective task `info` should use WarpSpeed, storing the
// decision in info->useWarpSpeed.
//
//   comm   - communicator; reads comm->topo->warpSpeedEnabled, comm->archName
//            and comm->nNodes, and may overwrite comm->unroll.
//   info   - task being configured; reads info->algorithm and info->func,
//            writes info->useWarpSpeed and (in one branch) info->nWarps.
//   nBytes - message size compared against the per-collective threshold.
//
// NOTE(review): as shown here the body does not have balanced braces: the two
// consecutive `if(rcclParamWarpSpeedAutoMode() != 0 ...` lines each open a
// block but only one matching close follows. This looks like a diff hunk
// rendered without +/- markers, i.e. old and new revisions of some lines
// interleaved -- confirm against the repository before relying on this text.
void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes) {
// Start from "disabled"; every early return below leaves whatever value was last assigned.
info->useWarpSpeed = false;
// Nothing to decide when the topology does not have WarpSpeed enabled.
if(!comm->topo->warpSpeedEnabled) {
return;
}
info->useWarpSpeed = (info->algorithm == NCCL_ALGO_RING); // Enabled by default for any RING algorithm when platform supports it
// NOTE(review): the next two lines test the same auto-mode parameter; the first
// also folds in the gfx950 check that the second performs as a nested early
// return. Presumably only one of them belongs to the final revision -- confirm.
if(rcclParamWarpSpeedAutoMode() != 0 && IsArchMatch(comm->archName, "gfx950")) { // Auto mode only available for gfx950 currently
if(rcclParamWarpSpeedAutoMode() != 0) { // auto mode
if(!IsArchMatch(comm->archName, "gfx950")) {
// Auto mode only available for gfx950 currently, keep it to false
// NOTE(review): if the RING default assignment above is in effect, the flag
// may already be true at this point rather than false -- confirm intent.
return;
}
// Size threshold for this collective; stays 0 for collectives not listed below.
size_t minBytes = 0;
commSetUnrollFactor(comm); // TODO: reset unroll factor per task rather than per comm
if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather) minBytes = RCCL_WARP_SPEED_MIN_BYTES;
else if (info->func == ncclFuncReduceScatter) minBytes = RCCL_WARP_SPEED_MIN_BYTES << 2; // ReduceScatter requires higher message size to benefit from WarpSpeed
if(comm->nNodes == 1) {
// `minBytes > 0` excludes collectives that were given no threshold above.
if(nBytes >= minBytes && minBytes > 0) {
// Note: this writes the communicator-wide unroll value, not a per-task one.
comm->unroll = NCCL_UNROLL_2;
info->nWarps = 4;
info->useWarpSpeed = true;
}
} else {
// comm->nNodes != 1: WarpSpeed is turned off for this task.
// TODO: set unroll factor per task rather than per comm
// NOTE(review): commSetUnrollFactor(comm) is also called before the threshold
// selection above; one of the two calls is presumably from an older revision.
commSetUnrollFactor(comm);
info->useWarpSpeed = false;
}
// Auto mode off: enable WarpSpeed for the RING algorithm. The warpSpeedEnabled
// test here repeats the condition already enforced by the early return at the
// top of the function, when that guard is present.
} else if (comm->topo->warpSpeedEnabled && info->algorithm == NCCL_ALGO_RING) {
info->useWarpSpeed = true;
}
}
#endif