[WarpSpeed] Improve handling for auto and manual modes (#2125)

* Force ring in WarpSpeed manual mode and log event

* Skip usage for non-ring in WarpSpeed auto mode

* Enable WarpSpeed when its CU count is set
Este commit está contenido en:
Mustafa Abduljabbar
2026-01-06 10:21:49 -05:00
cometido por GitHub
padre b4a86ef680
commit 93fdcb160c
Se han modificado 2 ficheros con 25 adiciones y 5 borrados
+2 -1
Ver fichero
@@ -1078,6 +1078,7 @@ NCCL_PARAM(AllocP2pNetLLBuffers, "ALLOC_P2P_NET_LL_BUFFERS", 0);
#ifdef ENABLE_WARP_SPEED
extern int64_t rcclParamWarpSpeedEnable();
extern int64_t rcclParamWarpSpeedAutoMode();
extern int64_t rcclParamWarpSpeedCuCount();
#endif
// MNNVL: Flag to indicate whether to enable Multi-Node NVLink
NCCL_PARAM(MNNVLEnable, "MNNVL_ENABLE", 2);
@@ -1458,7 +1459,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
}
}
#ifdef ENABLE_WARP_SPEED
comm->topo->warpSpeedEnabled = (rcclParamWarpSpeedEnable() != 0 || rcclParamWarpSpeedAutoMode() != 0);
comm->topo->warpSpeedEnabled = (rcclParamWarpSpeedEnable() != 0 || rcclParamWarpSpeedAutoMode() != 0 || rcclParamWarpSpeedCuCount() > 0);
#endif
// For single node communicators that do not uses the full xgmi links per gpu, i.e., nranks < 8
+23 -4
Ver fichero
@@ -51,11 +51,19 @@ void rcclRestrictMaxChannels(struct ncclComm* comm, int& nc ) {
}
}
static inline bool rcclCollSupportsRing(ncclFunc_t func) {
return (func == ncclFuncAllReduce ||
func == ncclFuncAllGather ||
func == ncclFuncReduceScatter ||
func == ncclFuncBroadcast ||
func == ncclFuncReduce);
}
int32_t rcclGetProtoForGfx12(ncclFunc_t collectiveFunc, size_t sizePerRank){
int returnVal = NCCL_PROTO_SIMPLE;
int SingleNodeLLCutoffs[] = {
int SingleNodeLLCutoffs[] = {
/*ncclFuncBroadcast*/ 1536,
/*ncclFuncReduce*/ 8192,
/*ncclFuncReduce*/ 8192,
/*ncclFuncAllGather*/ 98304,
/*ncclFuncReduceScatter*/ 98304,
/*ncclFuncAllReduce*/ 913532,
@@ -513,13 +521,20 @@ void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKe
void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes) {
info->useWarpSpeed = false;
if(rcclParamWarpSpeedAutoMode() != 0) { // auto mode
if(!rcclCollSupportsRing(info->func)) return;
if(rcclParamWarpSpeedAutoMode() != 0) { // Auto performance mode
if(!IsArchMatch(comm->archName, "gfx950")) {
// Auto mode only available for gfx950 currently, keep it to false
return;
}
size_t minBytes = 0;
commSetUnrollFactor(comm); // TODO: reset unroll factor per task rather than per comm
// No early return based on the algorithm at the start of the function
// to allow unroll factor to be reverted to default.
// This can be changed once per-task unroll factor setting is implemented.
if(info->algorithm != NCCL_ALGO_RING) {
return; // If Ring is not selected, assume it is suboptimal and return
}
if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather) minBytes = RCCL_WARP_SPEED_MIN_BYTES;
else if (info->func == ncclFuncReduceScatter) minBytes = RCCL_WARP_SPEED_MIN_BYTES << 2; // ReduceScatter requires higher message size to benefit from WarpSpeed
if(comm->nNodes == 1) {
@@ -529,7 +544,11 @@ void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size
info->useWarpSpeed = true;
}
}
} else if (comm->topo->warpSpeedEnabled && info->algorithm == NCCL_ALGO_RING) {
} else if (comm->topo->warpSpeedEnabled) {
if(info->algorithm != NCCL_ALGO_RING) {
INFO(NCCL_TUNING, "Overriding %s algorithm with RING for nccl%s at %zu bytes as WarpSpeed is requested and only supports RING", ncclAlgoToString(info->algorithm), ncclFuncToString(info->func), nBytes);
info->algorithm = NCCL_ALGO_RING; // Force Ring when WarpSpeed is enabled in manual mode as it only supports Ring
}
info->useWarpSpeed = true;
}
}