işlemeyi yapan:
GitHub
ebeveyn
2621e0254e
işleme
88652b53d0
@@ -474,26 +474,25 @@ void rcclSetWarpSpeedSupportAndFinalCuCount(struct ncclComm* comm, struct ncclKe
|
||||
|
||||
// Decides whether the WarpSpeed path is used for one collective task and stores the
// decision in info->useWarpSpeed; in auto mode it may also set info->nWarps and comm->unroll.
//   comm   - communicator; reads topo->warpSpeedEnabled, archName and nNodes, may write unroll.
//   info   - task being configured; reads algorithm and func, writes useWarpSpeed / nWarps.
//   nBytes - size compared against the per-collective minimum (minBytes) in auto mode.
// NOTE(review): this listing looks like a rendered diff in which removed and added lines are
// interleaved without +/- markers (see the two consecutive auto-mode `if` headers below), so
// the braces do not balance as shown - confirm against the real file before editing.
void rcclSetWarpSpeedAuto(struct ncclComm* comm, struct ncclTaskColl* info, size_t nBytes) {
|
||||
// Start from "disabled"; the paths below turn it back on.
info->useWarpSpeed = false;
|
||||
// Nothing to decide when the topology flag warpSpeedEnabled is off.
if(!comm->topo->warpSpeedEnabled) {
|
||||
return;
|
||||
}
|
||||
info->useWarpSpeed = (info->algorithm == NCCL_ALGO_RING); // Enabled by default for any RING algorithm when platform supports it
|
||||
// NOTE(review): the next two `if` headers test the same auto-mode parameter; presumably one
// is the added line and the other the removed line of the diff - TODO confirm which is current.
if(rcclParamWarpSpeedAutoMode() != 0 && IsArchMatch(comm->archName, "gfx950")) { // Auto mode only available for gfx950 currently
|
||||
if(rcclParamWarpSpeedAutoMode() != 0) { // auto mode
|
||||
if(!IsArchMatch(comm->archName, "gfx950")) {
|
||||
// Auto mode only available for gfx950 currently, keep it to false
|
||||
return;
|
||||
}
|
||||
// Size threshold for auto mode; stays 0 for collectives not listed below, which makes the
// `minBytes > 0` test further down fail for them.
size_t minBytes = 0;
|
||||
commSetUnrollFactor(comm); // TODO: reset unroll factor per task rather than per comm
|
||||
if(info->func == ncclFuncAllReduce || info->func == ncclFuncAllGather) minBytes = RCCL_WARP_SPEED_MIN_BYTES;
|
||||
else if (info->func == ncclFuncReduceScatter) minBytes = RCCL_WARP_SPEED_MIN_BYTES << 2; // ReduceScatter requires higher message size to benefit from WarpSpeed
|
||||
// Single-node case: enable WarpSpeed only when a threshold exists and nBytes reaches it.
if(comm->nNodes == 1) {
|
||||
if(nBytes >= minBytes && minBytes > 0) {
|
||||
comm->unroll = NCCL_UNROLL_2;
|
||||
info->nWarps = 4;
|
||||
info->useWarpSpeed = true;
|
||||
}
|
||||
} else {
|
||||
// nNodes != 1: WarpSpeed is forced off in auto mode.
// TODO: set unroll factor per task rather than per comm
|
||||
commSetUnrollFactor(comm);
|
||||
info->useWarpSpeed = false;
|
||||
}
|
||||
// Auto mode off: use WarpSpeed whenever the topology flag is set and the algorithm is RING.
} else if (comm->topo->warpSpeedEnabled && info->algorithm == NCCL_ALGO_RING) {
|
||||
info->useWarpSpeed = true;
|
||||
}
|
||||
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Yeni konuda referans
Bir kullanıcı engelle