[INIT] Fix fallback for unsupported user-specified runtime unroll factor (#1780)
* [INIT] Fix fallback for unsupported user-specified runtime unroll factor * Add CollTrace guard * Move `commSetUnrollFactor()` to rccl_wrap.cc * Modify comments in the device-code generator script
This commit is contained in:
@@ -52,19 +52,21 @@ else:
|
||||
# make ONLY_FUNCS="AllGather * *|Broadcast * *|SendRecv"
|
||||
#
|
||||
# # Only AllReduce Sum int32_t (but all algos, protos)
|
||||
# make ONLY_FUNCS="AllReduce * * Sum int32_t"
|
||||
# make ONLY_FUNCS="AllReduce * * Sum i32"
|
||||
#
|
||||
# # Only AllReduce RING Max float (but all protos and unrolls)
|
||||
# make ONLY_FUNCS="AllReduce RING * Max float"
|
||||
# make ONLY_FUNCS="AllReduce RING * Max f32"
|
||||
#
|
||||
# # AllReduce TREE LL128 Prod rccl_bfloat16 unroll=1
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod rccl_bfloat16 1"
|
||||
# make ONLY_FUNCS="AllReduce TREE LL128 Prod bf16 1"
|
||||
#
|
||||
# # AllReduce RING SIMPLE and ReduceScatter RING LL float (but all redops, types, unrolls for AllReduce and all redops, unrolls for ReduceScatter)
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * float *"
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE * *|ReduceScatter RING LL * f32 *"
|
||||
# --- or ---
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * float *"
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|AllGather RING LL/SIMPLE Sum int8_t 1/2/4|AllToAllPivot RING SIMPLE Sum int8_t 1/2/4|Broadcast RING LL/SIMPLE Sum int8_t 1/2/4|Reduce RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|ReduceScatter RING LL/SIMPLE Sum/MinMax int8_t/uint8_t/half/float/double/hip_bfloat16/rccl_float8/rccl_bfloat8 1/2/4|SendRecv RING SIMPLE Sum int8_t 1/2/4"
|
||||
# make ONLY_FUNCS="AllReduce RING SIMPLE|ReduceScatter RING LL * f32 *"
|
||||
#
|
||||
#
|
||||
# make ONLY_FUNCS="AllReduce RING/TREE LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|AllGather RING LL/LL128/SIMPLE Sum i8 1/2/4|AllToAllPivot RING SIMPLE Sum i8 1/2/4|Broadcast RING LL/LL128/SIMPLE Sum i8 1/2/4|Reduce RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|ReduceScatter RING LL/LL128/SIMPLE Sum/MinMax i8/u8/f16/f32/f64/bf16/f8e4m3/f8e5m2 1/2/4|SendRecv RING SIMPLE Sum i8 1/2/4"
|
||||
#
|
||||
# # ONLY_FUNCS can be used together for debugging
|
||||
|
||||
@@ -159,9 +161,6 @@ def func_filter(function_params, current_idx, item_list=None):
|
||||
|
||||
# If the paramter is equal to '*', include all possible cases for it
|
||||
if current_element == "*":
|
||||
if current_idx == 0:
|
||||
raise ValueError("Error: Paramter 'COLL' can not be type all '*'.")
|
||||
|
||||
# all_params list must be in the same order as function_params --> <coll> <algo> <proto> <redop> <type> <unroll>
|
||||
# Get the current list from all_params
|
||||
current_list = all_params[current_idx]
|
||||
|
||||
+4
-4
@@ -29,7 +29,7 @@
|
||||
using namespace rccl;
|
||||
|
||||
/* [RCCL] Determine which GPU kernel to execute */
|
||||
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task = NULL)
|
||||
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task)
|
||||
{
|
||||
// At this time, unroll factor is controlled only by passed in unroll argument
|
||||
// After more investigation, this may be further tuned by the actual task being processed
|
||||
@@ -48,9 +48,9 @@ void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* tas
|
||||
return rcclKernelTable[firstKernel + kernelIdx].funcPtr;
|
||||
}
|
||||
}
|
||||
// Fall back to default unroll
|
||||
WARN("Requested RCCL_UNROLL_FACTOR: %d does not exist in `rcclKernelTable`. Falling back to default unroll: %d", unroll, rcclKernelTable[firstKernel].unroll);
|
||||
return rcclKernelTable[firstKernel].funcPtr;
|
||||
|
||||
// If does not match, return null
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static int rcclProtoGrainSize(int proto, ncclComm *comm){
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
#define NCCL_SIMPLE_ALIGNMENT (WARP_SIZE * 8LL * 16LL)
|
||||
#define NCCL_BYTES_ALIGNMENT 16
|
||||
|
||||
void* rcclGetKernelIndex(int unroll, bool useCollTrace, struct ncclTaskColl* task = NULL);
|
||||
|
||||
ncclResult_t ncclInitKernelsForDevice(int cudaArch, int maxSharedMem, size_t* maxStackSize);
|
||||
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info);
|
||||
ncclResult_t ncclLaunchPrepare(struct ncclComm* comm);
|
||||
|
||||
@@ -23,6 +23,8 @@ THE SOFTWARE.
|
||||
#define RCCL_COMMON_H_
|
||||
#include "nccl_common.h"
|
||||
#include "nccl.h"
|
||||
#include "param.h"
|
||||
|
||||
typedef enum RcclTunableColls {
|
||||
RCCL_UNSUPPORTED_TUNABLE = -1,
|
||||
RCCL_RS_TUNABLE = 0, // reduce_scatter index
|
||||
@@ -78,4 +80,5 @@ ncclResult_t rcclGetAlgoInfo(struct ncclComm* comm, ncclFunc_t coll, uint64_t co
|
||||
int* algo, int* protocol, int* maxChannels);
|
||||
|
||||
ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count, size_t& maxCount);
|
||||
#endif
|
||||
ncclResult_t commSetUnrollFactor(struct ncclComm* comm);
|
||||
#endif
|
||||
|
||||
+1
-31
@@ -51,6 +51,7 @@
|
||||
#include "mscclpp/mscclpp_nccl.h"
|
||||
#endif
|
||||
#include "rocm_smi_wrap.h"
|
||||
#include "rccl_common.h"
|
||||
// [/RCCL]
|
||||
|
||||
#include "msccl/msccl_lifecycle.h"
|
||||
@@ -87,37 +88,6 @@ NCCL_PARAM(RuntimeConnect, "RUNTIME_CONNECT", 1);
|
||||
struct allocationTracker allocTracker[MAX_ALLOC_TRACK_NGPU] = {};
|
||||
static ncclResult_t commReclaim(ncclComm_t comm);
|
||||
|
||||
//RCCL runtime param to set Unroll Factor
|
||||
RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0);
|
||||
|
||||
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));
|
||||
|
||||
//If RCCL runtime param is set, it will override defaults
|
||||
if (rcclParamUnrollFactor() != 0) {
|
||||
comm->unroll = rcclParamUnrollFactor();
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll);
|
||||
}
|
||||
else {
|
||||
if (IsArchMatch(devProp.gcnArchName, "gfx950")) {
|
||||
//on gfx950, use unroll=1 for single-node and unroll=2 for multi-node
|
||||
if (comm->nNodes == 1)
|
||||
comm->unroll = 1;
|
||||
else
|
||||
comm->unroll = 2;
|
||||
}
|
||||
else if((IsArchMatch(devProp.gcnArchName, "gfx908")) ||
|
||||
(IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80))
|
||||
//on MI300X and gfx908, use unroll=2
|
||||
comm->unroll = 2;
|
||||
else
|
||||
comm->unroll = 4;
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll);
|
||||
}
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MSCCLPP
|
||||
size_t std::hash<ncclUniqueId>::operator ()(const ncclUniqueId& uniqueId) const noexcept {
|
||||
return (size_t)getHash(uniqueId.internal, NCCL_UNIQUE_ID_BYTES);
|
||||
|
||||
@@ -24,6 +24,7 @@ THE SOFTWARE.
|
||||
#include "comm.h"
|
||||
#include "graph/topo.h"
|
||||
#include "enqueue.h"
|
||||
|
||||
void rcclUpdateCollectiveProtocol(struct ncclComm* comm, size_t const& nBytes, struct ncclTaskColl* info) {
|
||||
// Honor user input for protocol choice
|
||||
static int userProtocolInput = -2;
|
||||
@@ -125,3 +126,44 @@ ncclResult_t rcclFuncMaxSendRecvCount(ncclFunc_t func, int nRanks, size_t count,
|
||||
maxCount = ncclFuncMaxSendRecvCount(func, nRanks, count);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
//RCCL runtime param to set Unroll Factor
|
||||
RCCL_PARAM(UnrollFactor, "UNROLL_FACTOR", 0);
|
||||
|
||||
ncclResult_t commSetUnrollFactor(struct ncclComm* comm) {
|
||||
hipDeviceProp_t devProp;
|
||||
CUDACHECK(hipGetDeviceProperties(&devProp, comm->cudaDev));
|
||||
|
||||
//If RCCL runtime param is set, it will override defaults, if supported
|
||||
if (rcclParamUnrollFactor() != 0) {
|
||||
#if ENABLE_COLLTRACE
|
||||
if(rcclGetKernelIndex(rcclParamUnrollFactor(), comm->collTraceEnabled)) {
|
||||
#else
|
||||
if(rcclGetKernelIndex(rcclParamUnrollFactor(), false)) {
|
||||
#endif
|
||||
comm->unroll = rcclParamUnrollFactor();
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (user-defined): %d", comm->unroll);
|
||||
return ncclSuccess;
|
||||
}
|
||||
else {
|
||||
// Fall back to default unroll
|
||||
WARN("Requested RCCL_UNROLL_FACTOR: %ld is invalid and does not exist in `rcclKernelTable`. Falling back to pre-set unroll.", rcclParamUnrollFactor());
|
||||
}
|
||||
}
|
||||
|
||||
if (IsArchMatch(devProp.gcnArchName, "gfx950")) {
|
||||
//on gfx950, use unroll=1 for single-node and unroll=2 for multi-node
|
||||
if (comm->nNodes == 1)
|
||||
comm->unroll = 1;
|
||||
else
|
||||
comm->unroll = 2;
|
||||
}
|
||||
else if((IsArchMatch(devProp.gcnArchName, "gfx908")) ||
|
||||
(IsArchMatch(devProp.gcnArchName, "gfx942") && devProp.multiProcessorCount > 80))
|
||||
//on MI300X and gfx908, use unroll=2
|
||||
comm->unroll = 2;
|
||||
else
|
||||
comm->unroll = 4;
|
||||
INFO(NCCL_INIT, "RCCL Unroll Factor (pre-set): %d", comm->unroll);
|
||||
return ncclSuccess;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user