[WarpSpeed] Improve handling for auto and manual modes (#2125)
* Force ring in WarpSpeed manual mode and log event
* Skip usage for non-ring in WarpSpeed auto mode
* Enable WarpSpeed when its CU count is set
[ROCm/rccl commit: 93fdcb160c]
This commit is contained in:
zatwierdzone przez
GitHub
rodzic
49d9f8cc27
commit
5bba932529
@@ -1078,6 +1078,7 @@ NCCL_PARAM(AllocP2pNetLLBuffers, "ALLOC_P2P_NET_LL_BUFFERS", 0);
|
||||
#ifdef ENABLE_WARP_SPEED
|
||||
extern int64_t rcclParamWarpSpeedEnable();
|
||||
extern int64_t rcclParamWarpSpeedAutoMode();
|
||||
extern int64_t rcclParamWarpSpeedCuCount();
|
||||
#endif
|
||||
// MNNVL: Flag to indicate whether to enable Multi-Node NVLink
|
||||
NCCL_PARAM(MNNVLEnable, "MNNVL_ENABLE", 2);
|
||||
@@ -1458,7 +1459,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, struct ncclComm* p
|
||||
}
|
||||
}
|
||||
#ifdef ENABLE_WARP_SPEED
|
||||
comm->topo->warpSpeedEnabled = (rcclParamWarpSpeedEnable() != 0 || rcclParamWarpSpeedAutoMode() != 0);
|
||||
comm->topo->warpSpeedEnabled = (rcclParamWarpSpeedEnable() != 0 || rcclParamWarpSpeedAutoMode() != 0 || rcclParamWarpSpeedCuCount() > 0);
|
||||
#endif
|
||||
|
||||
// For single node communicators that do not uses the full xgmi links per gpu, i.e., nranks < 8
|
||||
|
||||
@@ -51,11 +51,19 @@ void rcclRestrictMaxChannels(struct ncclComm* comm, int& nc ) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool rcclCollSupportsRing(ncclFunc_t func) {
|
||||
return (func == ncclFuncAllReduce ||
|
||||
func == ncclFuncAllGather ||
|
||||
func == ncclFuncReduceScatter ||
|
||||
func == ncclFuncBroadcast ||
|
||||
func == ncclFuncReduce);
|
||||
}
|
||||
|
||||
int32_t rcclGetProtoForGfx12(ncclFunc_t collectiveFunc, size_t sizePerRank){
|
||||
int returnVal = NCCL_PROTO_SIMPLE;
|
||||
int SingleNodeLLCutoffs[] = {
|
||||
int SingleNodeLLCutoffs[] = {
|
||||
/*ncclFuncBroadcast*/ 1536,
|
||||
/*ncclFuncReduce*/ 8192,
|
||||
/*ncclFuncReduce*/ 8192,
|
||||
/*ncclFuncAllGather*/ 98304,
|
||||
/*ncclFuncReduceScatter*/ 98304,
|
||||
/*ncclFuncAllReduce*/ 913532,
|
||||
@@ -513,13 +521,20 @@ void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKe
|
||||
|
||||
void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes) {
|
||||
info->useWarpSpeed = false;
|
||||
if(rcclParamWarpSpeedAutoMode() != 0) { // auto mode
|
||||
if(!rcclCollSupportsRing(info->func)) return;
|
||||
if(rcclParamWarpSpeedAutoMode() != 0) { // Auto performance mode
|
||||
if(!IsArchMatch(comm->archName, "gfx950")) {
|
||||
// Auto mode only available for gfx950 currently, keep it to false
|
||||
return;
|
||||
}
|
||||
size_t minBytes = 0;
|
||||
commSetUnrollFactor(comm); // TODO: reset unroll factor per task rather than per comm
|
||||
// No early return based on the algorithm at the start of the function
|
||||
// to allow unroll factor to be reverted to default.
|
||||
// This can be changed once per-task unroll factor setting is implemented.
|
||||
if(info->algorithm != NCCL_ALGO_RING) {
|
||||
return; // If Ring is not selected, assume it is suboptimal and return
|
||||
}
|
||||
if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather) minBytes = RCCL_WARP_SPEED_MIN_BYTES;
|
||||
else if (info->func == ncclFuncReduceScatter) minBytes = RCCL_WARP_SPEED_MIN_BYTES << 2; // ReduceScatter requires higher message size to benefit from WarpSpeed
|
||||
if(comm->nNodes == 1) {
|
||||
@@ -529,7 +544,11 @@ void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size
|
||||
info->useWarpSpeed = true;
|
||||
}
|
||||
}
|
||||
} else if (comm->topo->warpSpeedEnabled && info->algorithm == NCCL_ALGO_RING) {
|
||||
} else if (comm->topo->warpSpeedEnabled) {
|
||||
if(info->algorithm != NCCL_ALGO_RING) {
|
||||
INFO(NCCL_TUNING, "Overriding %s algorithm with RING for nccl%s at %zu bytes as WarpSpeed is requested and only supports RING", ncclAlgoToString(info->algorithm), ncclFuncToString(info->func), nBytes);
|
||||
info->algorithm = NCCL_ALGO_RING; // Force Ring when WarpSpeed is enabled in manual mode as it only supports Ring
|
||||
}
|
||||
info->useWarpSpeed = true;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user