Enable LL128 protocol support (#605)

* Enable LL128 protocol support

* Use shared memory object directly when possible
Šī revīzija ir iekļauta:
Wenkai Du
2022-09-08 14:45:27 -07:00
revīziju iesūtīja GitHub
vecāks d700a94918
revīzija 7bbce085cc
13 mainīti faili ar 382 papildinājumiem un 247 dzēšanām
@@ -574,7 +574,6 @@ template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE> {
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
using Proto = ProtoSimple<ALLREDUCE_CHUNKSTEPS/ALLREDUCE_SLICESTEPS, ALLREDUCE_SLICESTEPS>;
if (threadIdx.x == 0) __insert_timestamp(__LINE__);
runRing<T, RedOp, Proto>(args);
}
};
@@ -582,7 +581,6 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_SI
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE> {
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
if (threadIdx.x == 0) __insert_timestamp(__LINE__);
runTreeUpDown<T, RedOp, ProtoSimple<1, 1>>(args);
}
};
@@ -688,7 +686,6 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_COLLNET, NCCL_PROTO
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL> {
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
if (threadIdx.x == 0) __insert_timestamp(__LINE__);
runRing<T, RedOp, ProtoLL>(args);
}
};
@@ -696,7 +693,6 @@ struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_RING, NCCL_PROTO_LL
template<typename T, typename RedOp>
struct RunWorkElement<ncclFuncAllReduce, T, RedOp, NCCL_ALGO_TREE, NCCL_PROTO_LL> {
__device__ __attribute__((noinline)) void run(ncclWorkElem *args) {
if (threadIdx.x == 0) __insert_timestamp(__LINE__);
if (args->pad_0 == 0) runTreeUpDown<T, RedOp, ProtoLL>(args);
else runTreeSplit<T, RedOp, ProtoLL>(args);
}
+207 -111
Parādīt failu
@@ -102,26 +102,112 @@ static const __device__ constexpr ncclKernelFunc_t ncclFuncs[]{
#endif
};
template<unsigned short f, unsigned short l>
#define NCCL_FUNC5_LL128(func, algo, devredop, type, nullify) \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL, devredop, type)), \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, LL128, devredop, type)), \
MACRO_IF(nullify, nullptr, NCCL_FUNC_NAME(func, algo, SIMPLE, devredop, type))
#define NCCL_FUNC4_LL128(func, devredop, type, nullify) \
NCCL_FUNC5_LL128(func, TREE, devredop, type, nullify), \
NCCL_FUNC5_LL128(func, RING, devredop, type, nullify), \
NCCL_FUNC5_LL128(func, COLLNET, devredop, type, nullify)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A_LL128(func, devredop, nullForFloat) \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, uint8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int32_t, 0), \
NCCL_FUNC4_LL128(func, devredop, uint32_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int64_t, 0), \
NCCL_FUNC4_LL128(func, devredop, uint64_t, 0), \
NCCL_FUNC4_LL128(func, devredop, half, nullForFloat), \
NCCL_FUNC4_LL128(func, devredop, float, nullForFloat), \
NCCL_FUNC4_LL128(func, devredop, double, nullForFloat), \
NCCL_FUNC4_LL128(func, devredop, rccl_bfloat16, nullForFloat)
#define NCCL_FUNCS3B_LL128(func, devredop) \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0), \
NCCL_FUNC4_LL128(func, devredop, int8_t, 0)
// Must be consistent with ncclRedOp_t
#define NCCL_FUNCS2A_LL128(func) \
NCCL_FUNCS3A_LL128(func, Sum, /*nullForFloat=*/0), \
NCCL_FUNCS3A_LL128(func, Prod, /*nullForFloat=*/0), \
NCCL_FUNCS3A_LL128(func, Max, /*nullForFloat=*/0), \
NCCL_FUNCS3A_LL128(func, Min, /*nullForFloat=*/0), \
NCCL_FUNCS3A_LL128(func, PreMulSum, /*nullForFloat=*/0), \
NCCL_FUNCS3A_LL128(func, SumPostDiv, /*nullForFloat=*/1)
#define NCCL_FUNCS2B_LL128(func) \
NCCL_FUNCS3B_LL128(func, Sum), \
NCCL_FUNCS3B_LL128(func, Sum), \
NCCL_FUNCS3B_LL128(func, Sum), \
NCCL_FUNCS3B_LL128(func, Sum), \
NCCL_FUNCS3B_LL128(func, Sum), \
NCCL_FUNCS3B_LL128(func, Sum)
// Must be consistent with the ncclFuncSet enum
using ncclKernelFunc_t = void (*)();
static const __device__ constexpr ncclKernelFunc_t ncclFuncs_ll128[]{
// Don't try to initialize the host shadow copy of this device-side global
// variable. There is no host pointer to a device-side function, which
// confuses clang. This will be fixed in the next clang release.
#if defined(__HIP_DEVICE_COMPILE__)
#if defined(BUILD_ALLREDUCE_ONLY)
NCCL_FUNC4_LL128(AllReduce, Sum, float, 0),
#else
NCCL_FUNCS2B_LL128(Broadcast),
NCCL_FUNCS2A_LL128(Reduce),
NCCL_FUNCS2B_LL128(AllGather),
NCCL_FUNCS2A_LL128(ReduceScatter),
NCCL_FUNCS2A_LL128(AllReduce),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint8_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int32_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint32_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, int64_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, uint64_t),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, half),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, float),
NCCL_ONERANK_REDUCE_NAME(PreMulSum, double),
#if defined(RCCL_BFLOAT16)
NCCL_ONERANK_REDUCE_NAME(PreMulSum, rccl_bfloat16),
#endif
NCCL_FUNC_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
NCCL_FUNC_NAME(AllToAllPivot, RING, SIMPLE, Sum, int8_t),
#endif
#endif
};
template<unsigned short f, unsigned short l, bool u>
struct Caller {
static __device__ __host__
void call(unsigned short funcIndex) noexcept
{
constexpr unsigned short m = f + (l - f) / 2;
return (funcIndex < m) ? Caller<f, m>::call(funcIndex) : Caller<m, l>::call(funcIndex);
return (funcIndex < m) ? Caller<f, m, u>::call(funcIndex) : Caller<m, l, u>::call(funcIndex);
}
};
template<unsigned short f>
struct Caller<f, f + 1>{
template<unsigned short f, bool u>
struct Caller<f, f + 1, u>{
static __device__ __host__
void call(unsigned short funcIndex) noexcept { ncclFuncs[f](); }
void call(unsigned short funcIndex) noexcept { if (u) ncclFuncs_ll128[f](); else ncclFuncs[f](); }
};
static_assert(FUNC_INDEX_P2P == 2710, "Wrong P2P function index");
static_assert(FUNC_INDEX_ALLTOALL_PIVOT == 2711, "Wrong AllToAllPivot function index");
template<bool USING_LL128>
inline
__device__
void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {
@@ -130,43 +216,59 @@ void NCCL_CALL_FUNCTIONS(unsigned short funcIndex) noexcept {
ncclFunction_AllReduce_RING_SIMPLE_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL))
ncclFunction_AllReduce_RING_LL_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
ncclFunction_AllReduce_RING_LL128_Sum_float();
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_RING, NCCL_PROTO_LL128))
ncclFunction_AllReduce_RING_LL_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_TREE_SIMPLE_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL))
ncclFunction_AllReduce_TREE_LL_Sum_float();
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128))
ncclFunction_AllReduce_TREE_LL128_Sum_float();
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_TREE, NCCL_PROTO_LL128))
ncclFunction_AllReduce_TREE_LL_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_SIMPLE))
ncclFunction_AllReduce_COLLNET_SIMPLE_Sum_float();
else if (funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL))
ncclFunction_AllReduce_COLLNET_LL_Sum_float();
else if (USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL128))
ncclFunction_AllReduce_COLLNET_LL128_Sum_float();
else if (!USING_LL128 && funcIndex == FUNC_INDEX(ncclFuncAllReduce, ncclSum, ncclFloat32, NCCL_ALGO_COLLNET, NCCL_PROTO_LL128))
ncclFunction_AllReduce_COLLNET_LL_Sum_float();
else
assert("Unsupported function index");
#else
if (funcIndex < 540) {
if (funcIndex % 9 == 0) ncclFunction_Broadcast_TREE_LL_Sum_int8_t();
else if (funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 9 == 1) ncclFunction_Broadcast_TREE_LL_Sum_int8_t();
else if (funcIndex % 9 == 2) ncclFunction_Broadcast_TREE_SIMPLE_Sum_int8_t();
else if (funcIndex % 9 == 3) ncclFunction_Broadcast_RING_LL_Sum_int8_t();
else if (funcIndex % 9 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 9 == 4) ncclFunction_Broadcast_RING_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 9 == 4) ncclFunction_Broadcast_RING_LL_Sum_int8_t();
else if (funcIndex % 9 == 5) ncclFunction_Broadcast_RING_SIMPLE_Sum_int8_t();
else if (funcIndex % 9 == 6) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t();
else if (funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 9 == 7) ncclFunction_Broadcast_COLLNET_LL_Sum_int8_t();
else ncclFunction_Broadcast_COLLNET_SIMPLE_Sum_int8_t();
}
else if (funcIndex < 1080) Caller<540, 1080>::call(funcIndex);
else if (funcIndex < 1080) Caller<540, 1080, USING_LL128>::call(funcIndex);
else if (funcIndex < 1620) {
if (funcIndex % 9 == 0) ncclFunction_AllGather_TREE_LL_Sum_int8_t();
else if (funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 9 == 1) ncclFunction_AllGather_TREE_LL_Sum_int8_t();
else if (funcIndex % 9 == 2) ncclFunction_AllGather_TREE_SIMPLE_Sum_int8_t();
else if (funcIndex % 9 == 3) ncclFunction_AllGather_RING_LL_Sum_int8_t();
else if (funcIndex % 9 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 9 == 4) ncclFunction_AllGather_RING_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 9 == 4) ncclFunction_AllGather_RING_LL_Sum_int8_t();
else if (funcIndex % 9 == 5) ncclFunction_AllGather_RING_SIMPLE_Sum_int8_t();
else if (funcIndex % 9 == 6) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t();
else if (funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t();
else if (USING_LL128 && funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL128_Sum_int8_t();
else if (!USING_LL128 && funcIndex % 9 == 7) ncclFunction_AllGather_COLLNET_LL_Sum_int8_t();
else ncclFunction_AllGather_COLLNET_SIMPLE_Sum_int8_t();
}
else if (funcIndex < 2700) Caller<1620, 2700>::call(funcIndex);
else if (funcIndex < 2700) Caller<1620, 2700, USING_LL128>::call(funcIndex);
else {
switch (funcIndex - 2700) {
case 0:
@@ -219,52 +321,52 @@ class ncclFunction {
#ifdef ENABLE_COLLTRACE
#define traceColl(elem,launch_type) \
uint32_t pos = __atomic_fetch_add(ncclShmem->comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
ncclShmem->comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
ncclShmem->comm.collTrace[pos].bid = blockIdx.x; \
ncclShmem->comm.collTrace[pos].funcIndex = ncclShmem->work.header.funcIndex; \
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (ncclShmem->comm.collTrace[pos].data_0)); \
uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
shmem.comm.collTrace[pos].bid = blockIdx.x; \
shmem.comm.collTrace[pos].funcIndex = shmem.work.header.funcIndex; \
asm volatile ("s_getreg_b32 %0, hwreg(HW_REG_HW_ID)" : "=s" (shmem.comm.collTrace[pos].data_0)); \
if (elem.header.type == ncclWorkTypeP2p) { \
struct ncclWorkElemP2p *p2pElems = (struct ncclWorkElemP2p *)&elem; \
ncclShmem->comm.collTrace[pos].p2p[0].connIndex = p2pElems[0].connIndex; \
ncclShmem->comm.collTrace[pos].p2pOpCount[0] = p2pElems[0].opCount; \
ncclShmem->comm.collTrace[pos].p2p[0].ngroups = p2pElems[0].ngroups; \
ncclShmem->comm.collTrace[pos].p2p[0].nWarps = p2pElems[0].nWarps; \
ncclShmem->comm.collTrace[pos].p2p[0].warpStart = p2pElems[0].warpStart; \
ncclShmem->comm.collTrace[pos].p2p[0].peer = (uint16_t)(p2pElems[0].peer); \
ncclShmem->comm.collTrace[pos].p2p[1].connIndex = p2pElems[1].connIndex; \
ncclShmem->comm.collTrace[pos].p2pOpCount[1] = p2pElems[1].opCount; \
ncclShmem->comm.collTrace[pos].p2p[1].ngroups = p2pElems[1].ngroups; \
ncclShmem->comm.collTrace[pos].p2p[1].nWarps = p2pElems[1].nWarps; \
ncclShmem->comm.collTrace[pos].p2p[1].warpStart = p2pElems[1].warpStart; \
ncclShmem->comm.collTrace[pos].p2p[1].peer = (uint16_t)(p2pElems[1].peer); \
ncclShmem->comm.collTrace[pos].type = (ncclCollTraceP2pElemType|launch_type); \
shmem.comm.collTrace[pos].p2p[0].connIndex = p2pElems[0].connIndex; \
shmem.comm.collTrace[pos].p2pOpCount[0] = p2pElems[0].opCount; \
shmem.comm.collTrace[pos].p2p[0].ngroups = p2pElems[0].ngroups; \
shmem.comm.collTrace[pos].p2p[0].nWarps = p2pElems[0].nWarps; \
shmem.comm.collTrace[pos].p2p[0].warpStart = p2pElems[0].warpStart; \
shmem.comm.collTrace[pos].p2p[0].peer = (uint16_t)(p2pElems[0].peer); \
shmem.comm.collTrace[pos].p2p[1].connIndex = p2pElems[1].connIndex; \
shmem.comm.collTrace[pos].p2pOpCount[1] = p2pElems[1].opCount; \
shmem.comm.collTrace[pos].p2p[1].ngroups = p2pElems[1].ngroups; \
shmem.comm.collTrace[pos].p2p[1].nWarps = p2pElems[1].nWarps; \
shmem.comm.collTrace[pos].p2p[1].warpStart = p2pElems[1].warpStart; \
shmem.comm.collTrace[pos].p2p[1].peer = (uint16_t)(p2pElems[1].peer); \
shmem.comm.collTrace[pos].type = (ncclCollTraceP2pElemType|launch_type); \
} else { \
ncclShmem->comm.collTrace[pos].opCount = elem.opCount; \
ncclShmem->comm.collTrace[pos].coll.nWarps = elem.header.nWarps; \
ncclShmem->comm.collTrace[pos].coll.bid = elem.bid; \
ncclShmem->comm.collTrace[pos].coll.nChannels = elem.nChannels; \
ncclShmem->comm.collTrace[pos].type = (ncclCollTraceCollElemType|launch_type); \
shmem.comm.collTrace[pos].opCount = elem.opCount; \
shmem.comm.collTrace[pos].coll.nWarps = elem.header.nWarps; \
shmem.comm.collTrace[pos].coll.bid = elem.bid; \
shmem.comm.collTrace[pos].coll.nChannels = elem.nChannels; \
shmem.comm.collTrace[pos].type = (ncclCollTraceCollElemType|launch_type); \
}
#define traceKernelLaunch(elem,firstLaunch) { \
traceColl(elem,(firstLaunch?ncclCollTraceKernelLaunchType:ncclCollTraceCollLaunchType)); \
}
#define traceKernelEnd() { \
uint32_t pos = __atomic_fetch_add(ncclShmem->comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
ncclShmem->comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
ncclShmem->comm.collTrace[pos].bid = bid; \
ncclShmem->comm.collTrace[pos].type = ncclCollTraceKernelEndType; \
uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
shmem.comm.collTrace[pos].bid = bid; \
shmem.comm.collTrace[pos].type = ncclCollTraceKernelEndType; \
}
#define traceAbort() { \
uint32_t pos = __atomic_fetch_add(ncclShmem->comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
ncclShmem->comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
ncclShmem->comm.collTrace[pos].bid = bid; \
ncclShmem->comm.collTrace[pos].type = ncclCollTraceAbortType; \
uint32_t pos = __atomic_fetch_add(shmem.comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
shmem.comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
shmem.comm.collTrace[pos].bid = bid; \
shmem.comm.collTrace[pos].type = ncclCollTraceAbortType; \
}
// traceData(int16_t data2, uint32_t data4, uint64_t data8_0, uint64_t data8_1)
#define traceData(data2, data4, data8_0, data8_1) { \
uint32_t pos = __atomic_fetch_add(ncclShmem->comm.collTraceTail, 1, __ATOMIC_SEQ_CST)%COLLTRACE_NUM_ITEMS; \
uint32_t pos = atomicAdd(ncclShmem->comm.collTraceTail, 1)%COLLTRACE_NUM_ITEMS; \
ncclShmem->comm.collTrace[pos].bid = blockIdx.x; \
ncclShmem->comm.collTrace[pos].timeStamp = __builtin_amdgcn_s_memrealtime(); \
ncclShmem->comm.collTrace[pos].funcIndex = data2; \
@@ -281,22 +383,16 @@ class ncclFunction {
#ifdef ENABLE_PROFILING
#define __insert_timestamp(line_num) do { \
if (ncclShmem->prof.count < PROFILE_NUM_ITEMS) { \
ncclShmem->prof.elem[ncclShmem->prof.count].line = line_num; \
ncclShmem->prof.elem[ncclShmem->prof.count].timeStamp = __builtin_amdgcn_s_memrealtime(); \
ncclShmem->prof.count++; \
if (shmem.prof.count < PROFILE_NUM_ITEMS) { \
shmem.prof.elem[shmem.prof.count].line = line_num; \
shmem.prof.elem[shmem.prof.count].timeStamp = __builtin_amdgcn_s_memrealtime(); \
shmem.prof.count++; \
} \
} while(0);
#else
#define __insert_timestamp(line_num)
#endif
__device__ inline bool barrierReduceAny(int bit, uint32_t* abortCount) {
if (bit) atomicAdd(abortCount, 1);
__syncthreads();
return atomicAdd(abortCount, 0) != 0;
}
// Copy src to dst and fill extra size with zeroes
template<typename Tdst, typename Tsrc>
__device__ void copyToShmem(Tdst *dst, Tsrc const *src, int tid, int nthreads) {
@@ -385,19 +481,18 @@ struct ncclShmemGroup {
struct ncclShmemData {
union {
uint64_t ll128warp[NCCL_MAX_GROUPS][NCCL_MAX_GROUPS];
struct ncclShmemGroup groups[NCCL_MAX_GROUPS];
};
uint32_t sync[NCCL_MAX_GROUPS];
uint64_t redOpArgs[NCCL_MAX_DIRECT_ARITY+1];
struct ncclDevComm comm;
struct ncclChannel channel;
uint64_t pad;
uint64_t pad[2];
struct ncclWork work;
#ifdef ENABLE_PROFILING
struct ncclProf prof;
#endif
};
static_assert(offsetof(struct ncclShmemData, work)%16 == 0, "shmem.work needs to be 16B aligned");
static __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we) {
if (we->header.type != ncclWorkTypeUnused && we->redOpArgIsPtr) {
@@ -422,16 +517,14 @@ static __device__ void ncclRedopPtrDeref(struct ncclWorkElem* we) {
extern __device__ struct ncclShmemData *ncclShmem;
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int FnIndex, bool COLLTRACE>
__device__ void ncclKernel(struct ncclDevComm* comm, ncclWorkElem first) {
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto, int FnIndex, bool COLLTRACE, bool USING_LL128>
__device__ void ncclKernel(struct ncclDevComm* comm) {
int tid = threadIdx.x;
int nthreads = blockDim.x;
int bid = blockIdx.x;
__shared__ struct ncclShmemData shmem;
ncclShmem = &shmem;
__shared__ uint32_t abortCount;
if (tid == 0) {
abortCount = 0;
for (auto i = 0; i < NCCL_MAX_GROUPS; i++) {
shmem.groups[i].barrier = 0;
for (auto j = 0; j < NCCL_MAX_GROUPS; j++) shmem.groups[i].barrier_next[j] = 0;
@@ -439,48 +532,38 @@ __device__ void ncclKernel(struct ncclDevComm* comm, ncclWorkElem first) {
}
__syncthreads();
int turn = copyToShmem(&ncclShmem->comm, comm);
int turn = copyToShmem(&shmem.comm, comm);
#ifdef ENABLE_PROFILING
if (tid == 0) {
ncclShmem->prof.count = 0;
ncclShmem->prof.seq = ncclShmem->comm.devProf[bid].seq;
shmem.prof.count = 0;
shmem.prof.seq = shmem.comm.devProf[bid].seq;
}
#endif
if (tid == 0) __insert_timestamp(__LINE__);
// get address of channel without incurring indirect load from ncclDevCom::channels
ncclChannel *channel = &((ncclDevCommAndChannels*)comm)->channels[bid];
turn = copyToShmem(&ncclShmem->channel, channel, turn);
turn = copyToShmem(&shmem.channel, channel, turn);
// To optimize for latency, (only) the first operation is passed as argument.
if (bid == 0 && first.header.type != ncclWorkTypeUnused) {
// Copy first elem to work and zero out the rest
copyToShmem(&ncclShmem->work, &first, tid, nthreads);
}
__syncthreads(); // publish ncclShmem
if (tid == 0) __insert_timestamp(__LINE__);
if (tid == 0) __insert_timestamp(__LINE__);
ncclWork *workFifoHost = ncclShmem->channel.workFifo;
ncclWork *workFifoDev = ncclShmem->channel.workFifoDev;
int workFifoIx = ncclShmem->channel.index;
bool skipLoadWork = false, firstLaunch = true;
if (bid == 0 && first.header.type != ncclWorkTypeUnused)
skipLoadWork = true;
ncclWork *workFifoHost = shmem.channel.workFifo;
ncclWork *workFifoDev = shmem.channel.workFifoDev;
int workFifoIx = shmem.channel.index;
bool firstLaunch = true;
while (true) {
if (!skipLoadWork) {
copyToShmem(&ncclShmem->work, &workFifoDev[workFifoIx], tid, nthreads);
if (tid == 0) __insert_timestamp(__LINE__);
{ // Check whether the last operation was aborted and make sure all threads exit
int aborted = tid == 0 ? *comm->abortFlag : 0;
if (barrierReduceAny(aborted, &abortCount)) { // publish ncclShmem->work
if (COLLTRACE && tid == 0) traceAbort();
break;
}
if (tid == 0)
workFifoHost[workFifoIx].header.type = ncclWorkTypeUnused;
copyToShmem(&shmem.work, &workFifoDev[workFifoIx], tid, nthreads);
if (tid == 0) __insert_timestamp(__LINE__);
{ // Check whether the last operation was aborted and make sure all threads exit
int aborted = tid == 0 ? *comm->abortFlag : 0;
if (__any(aborted)) { // publish shmem.work
if (COLLTRACE && tid == 0) traceAbort();
break;
}
if (tid == 0)
workFifoHost[workFifoIx].header.type = ncclWorkTypeUnused;
}
if (tid == 0) __insert_timestamp(__LINE__);
@@ -489,47 +572,59 @@ __device__ void ncclKernel(struct ncclDevComm* comm, ncclWorkElem first) {
channel->index = workFifoIx; // write back to real channel, not shmem shadow
__syncwarp();
if (ncclShmem->work.header.type == ncclWorkTypeColl) {
if (tid < NCCL_MAX_WORK_ELEMENTS) ncclRedopPtrDeref(&ncclShmem->work.elems[tid]);
} else if (ncclShmem->work.header.type == ncclWorkTypeRegColl) {
if (tid < NCCL_MAX_WORK_ELEMENTS_REG) ncclRedopPtrDeref(&ncclShmem->work.regElems[tid].elem);
if (shmem.work.header.type == ncclWorkTypeColl) {
if (tid < NCCL_MAX_WORK_ELEMENTS) ncclRedopPtrDeref(&shmem.work.elems[tid]);
} else if (shmem.work.header.type == ncclWorkTypeRegColl) {
if (tid < NCCL_MAX_WORK_ELEMENTS_REG) ncclRedopPtrDeref(&shmem.work.regElems[tid].elem);
}
__syncthreads();
if (COLLTRACE && tid == 0) {
traceKernelLaunch(ncclShmem->work.elems[0],firstLaunch);
traceKernelLaunch(shmem.work.elems[0],firstLaunch);
firstLaunch = false;
#pragma unroll 1
for(int e=1; e < NCCL_MAX_WORK_ELEMENTS && ncclShmem->work.elems[e].header.type != ncclWorkTypeUnused; e ++) {
traceColl(ncclShmem->work.elems[e], 0);
for(int e=1; e < NCCL_MAX_WORK_ELEMENTS && shmem.work.elems[e].header.type != ncclWorkTypeUnused; e ++) {
traceColl(shmem.work.elems[e], 0);
}
}
if (tid == 0) __insert_timestamp(__LINE__);
if (ncclShmem->work.header.funcIndex == FnIndex)
RunWork<Fn, T, RedOp, Algo, Proto>().run(&ncclShmem->work);
if (shmem.work.header.funcIndex == FnIndex)
RunWork<Fn, T, RedOp, Algo, Proto>().run(&shmem.work);
else
NCCL_CALL_FUNCTIONS(ncclShmem->work.header.funcIndex);
NCCL_CALL_FUNCTIONS<USING_LL128>(shmem.work.header.funcIndex);
if (ncclShmem->work.header.isLast) break;
if (shmem.work.header.isLast) break;
__syncthreads();
skipLoadWork = false;
}
if (COLLTRACE && tid == 0) traceKernelEnd()
#ifdef ENABLE_PROFILING
if (ncclShmem->comm.devProf->seq < PROFILE_NUM_LAUNCHES) {
copyToShmem(ncclShmem->comm.devProf+MAXCHANNELS*ncclShmem->prof.seq+blockIdx.x, &ncclShmem->prof);
if (tid == 0) ncclShmem->comm.devProf[bid].seq++;
if (shmem.comm.devProf->seq < PROFILE_NUM_LAUNCHES) {
__syncthreads();
copyToShmem(shmem.comm.devProf+MAXCHANNELS*shmem.prof.seq+blockIdx.x, &shmem.prof);
if (tid == 0) shmem.comm.devProf[bid].seq++;
}
#endif
}
#define IMPL_COLL_KERN(func, algo, proto, devredop, type, fIndex) \
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, ncclWorkElem first) { \
if (comm->collTraceThread) \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true>(comm, first); \
else \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false>(comm, first); \
__global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm) { \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false, false>(comm); \
} \
\
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
__global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm) { \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true, false>(comm); \
} \
\
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
__global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm) { \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, false, true>(comm); \
} \
\
__launch_bounds__(NCCL_MAX_NTHREADS, 1) \
__global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm) { \
ncclKernel<ncclFunc##func, type, Func##devredop<type>, NCCL_ALGO_##algo, NCCL_PROTO_##proto, fIndex, true, true>(comm); \
}
// Examples : AllReduce, RING, LL, Sum, uint8
@@ -542,7 +637,8 @@ __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, dev
// Only generate inline kernels for LL
#define IMPL_COLL4(func, algo, devredop, type, ncclType) \
IMPL_COLL_FUNC(func, algo, LL, devredop, type) \
IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type) \
IMPL_COLL_FUNC(func, algo, LL128, devredop, type) \
IMPL_COLL_FUNC(func, algo, SIMPLE, devredop, type)
#define IMPL_COLL3(func, devredop, type, ncclType) \
IMPL_COLL4(func, TREE, devredop, type, ncclType) \
-19
Parādīt failu
@@ -9,25 +9,6 @@
#define OP128_H_
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
inline __device__ uint64_t* shmemCvtPtr(volatile uint64_t* shmemGenericPtr) {
return 0;
}
inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) {
}
inline __device__ void store128(uint64_t* ptr, uint64_t v0, uint64_t v1) {
}
inline __device__ void loadShmem128(uint64_t* shmemAsmPtr, uint64_t &v0, uint64_t &v1) {
}
inline __device__ void storeShmem128(uint64_t* shmemAsmPtr, uint64_t v0, uint64_t v1) {
}
template<typename T>
inline __device__ void loadShmemMisaligned128(T *ptr, uint64_t &v0, uint64_t &v1) {
}
#else
inline __device__ void load128(const uint64_t* ptr, uint64_t &v0, uint64_t &v1) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];"
+105 -71
Parādīt failu
@@ -9,7 +9,6 @@
#define NCCL_LL128_FLAGTHREAD (NCCL_LL128_LINEELEMS-1)
#define __any_sync(WARP_MASK, needReload) (true)
template<typename T, typename RedOp, typename Fan, int Direct, int P2p>
class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
@@ -51,20 +50,25 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
inline __device__ uint64_t recvFlag(int i) { return recvStep[i]+1; }
inline __device__ uint64_t sendFlag(int i) { return sendStep[i]+1; }
uint64_t* barriers;
uint64_t* barrier_next;
inline __device__ void barrier() {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
__syncthreads();
if (nthreads != WARP_SIZE)
barrier_by_group();
#else
asm volatile ("bar.sync %1, %0;" :: "r"(nthreads), "r"(15-group));
#endif
}
uint32_t abort = 0;
uint32_t* sync;
inline __device__ int checkAbort(int &spins, int i, int send) {
spins++;
if (abort == 0 && spins == NCCL_SPINS_BEFORE_CHECK_ABORT) {
abort = *ncclShmem->comm.abortFlag;
abort = __atomic_load_n(ncclShmem->comm.abortFlag, __ATOMIC_SEQ_CST);
spins = 0;
}
return abort;
@@ -74,67 +78,69 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
if (sendConnHeadPtr) {
int spins = 0;
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + 1) {
sendConnHeadCache = *sendConnHeadPtr;
__builtin_amdgcn_s_sleep(8);
sendConnHeadCache = atomicAdd_system((unsigned long long *)sendConnHeadPtr, 0);
if (checkAbort(spins, wid, 1)) break;
}
__asm__ __volatile__("s_wakeup");
if (sendConnFifoPtr) {
sendConnFifoPtr[sendStep[wid]%NCCL_STEPS] = nbytes;
__atomic_store_n(sendConnFifoPtr+sendStep[wid]%NCCL_STEPS, nbytes, __ATOMIC_SEQ_CST);
}
sendConnHead += 1;
}
}
inline __device__ void postRecv() {
if (recvConnHeadPtr) *recvConnHeadPtr = recvConnHead += 1;
if (recvConnHeadPtr) atomicExch_system((unsigned long long *)recvConnHeadPtr, recvConnHead += 1);
}
inline __device__ void postSend() {
if (sendConnTailPtr) { __threadfence(); *sendConnTailPtr = sendConnTail += 1; }
if (sendConnTailPtr) { __threadfence(); atomicExch_system((unsigned long long *)sendConnTailPtr, sendConnTail += 1); }
}
template<int WordPerThread>
__device__ __forceinline__ void loadRegsBegin(uint64_t(&regs)[WordPerThread], T const *src, int eltN) {
constexpr int EltPer16B = 16/sizeof(T);
if(reinterpret_cast<uintptr_t>(src)%16 == 0) {
/* We are aligned to 16 bytes, so load directly to registers no shmem.
* Flag threads load half as much data which gets shuffled to the even
* registers during Finish. The point of splitting into two phases is to
* defer that shuffle, which incurs a dependency stall, until after other
* memops are launched by the caller.
*/
#pragma unroll
for(int g=0; g < WordPerThread/2; g++) {
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
if(!flagThread || g%2==0) {
if(ix*EltPer16B < eltN)
load128((uint64_t*)(src + ix*EltPer16B), regs[2*g+0], regs[2*g+1]);
}
}
}
else {
// Not aligned. Stage the smallest 16 byte aligned region subsuming the
// buffer into shmem.
int misalignment = reinterpret_cast<uintptr_t>(src) % 16;
uint64_t *src8 = reinterpret_cast<uint64_t*>(reinterpret_cast<uintptr_t>(src) & -uintptr_t(16));
uint64_t *shm8 = shmemCvtPtr(ncclShmem->ll128warp[warp]);
#pragma unroll
for(int g=0; g < WordPerThread/2; g++)
if((g*WARP_SIZE + wid)*16 < misalignment + eltN*sizeof(T))
load128(src8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]);
#pragma unroll
for(int g=0; g < WordPerThread/2; g++)
storeShmem128(shm8 + 2*(g*WARP_SIZE + wid), regs[2*g+0], regs[2*g+1]);
__syncwarp();
// Now load from shmem stage to regs. Preserve the same pre-shuffled layout
// as the aligned case since Finish() will be applied regardless.
T *shm = (T*)shm8 + misalignment/sizeof(T);
#pragma unroll
for(int g=0; g < WordPerThread/2; g++) {
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
if(!flagThread || g%2==0) {
if(ix*EltPer16B < eltN)
loadShmemMisaligned128(shm + ix*EltPer16B, regs[2*g+0], regs[2*g+1]);
/* We are aligned to 16 bytes, so load directly to registers no shmem.
* Flag threads load half as much data which gets shuffled to the even
* registers during Finish. The point of splitting into two phases is to
* defer that shuffle, which incurs a dependency stall, until after other
* memops are launched by the caller.
*/
#pragma unroll
for(int g=0; g < WordPerThread/2; g++) {
int ix = g*WARP_SIZE - 16*(g/2) + wid - (g%2)*(wid/4);
if(!flagThread || g%2==0) {
if(ix*EltPer16B < eltN) {
if(reinterpret_cast<uintptr_t>(src)%4 == 0) {
regs[2*g+0] = __builtin_nontemporal_load((uint64_t*)(src + ix*EltPer16B));
regs[2*g+1] = __builtin_nontemporal_load((uint64_t*)(src + ix*EltPer16B)+1);
} else {
union {
uint64_t regs64[WordPerThread];
uint32_t regs32[WordPerThread*2];
uint16_t regs16[WordPerThread*4];
uint8_t regs8[WordPerThread*8];
};
if (sizeof(T) == 8) {
uint64_t *src64 = (uint64_t*)(src+ix*EltPer16B);
for (int i=0; i < 2; i++)
regs64[2*g+i] = __builtin_nontemporal_load(src64+i);
} else if (sizeof(T) == 4) {
uint32_t *src32 = (uint32_t*)(src+ix*EltPer16B);
for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T); i++)
regs32[2*g+i] = __builtin_nontemporal_load(src32+i);
} else if (sizeof(T) == 2) {
uint16_t *src16 = (uint16_t*)(src+ix*EltPer16B);
for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T); i++)
regs16[2*g+i] = __builtin_nontemporal_load(src16+i);
} else if (sizeof(T) == 1) {
uint8_t *src8 = (uint8_t*)(src+ix*EltPer16B);
for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T); i++)
regs8[2*g+i] = __builtin_nontemporal_load(src8+i);
}
regs[2*g+0] = regs64[2*g+0];
regs[2*g+1] = regs64[2*g+1];
}
}
}
}
@@ -157,26 +163,45 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
for (int g=1; g < WordPerThread/2; g+=2) {
if (flagThread) regs[2*g-1] = regs[2*g];
}
// Write to dst if 16-byte aligned, shmem otherwise.
int misalignment = reinterpret_cast<uintptr_t>(dst)%16;
uint64_t *shm8 = shmemCvtPtr(ncclShmem->ll128warp[warp]);
// Write to dst if 4-byte aligned, shmem otherwise.
int misalignment = reinterpret_cast<uintptr_t>(dst)%4;
#pragma unroll
for(int g=0; g < WordPerThread/2; g++) {
int ix = g*WARP_SIZE - 4*(g/2) + wid - (g%2)*(wid/8);
int ix = g*WARP_SIZE - 16*(g/2) + wid - (g%2)*(wid/4);
if (!flagThread || g%2==0) {
if(misalignment == 0 && (ix+1)*EltPer16B <= eltN)
store128((uint64_t*)(dst + ix*EltPer16B), regs[2*g+0], regs[2*g+1]);
else
storeShmem128(shm8+2*ix, regs[2*g+0], regs[2*g+1]);
if(misalignment == 0 && (ix+1)*EltPer16B <= eltN) {
__builtin_nontemporal_store(regs[2*g+0], (uint64_t*)(dst + ix*EltPer16B));
__builtin_nontemporal_store(regs[2*g+1], (uint64_t*)(dst + ix*EltPer16B)+1);
} else {
union {
uint64_t regs64[WordPerThread];
uint32_t regs32[WordPerThread*2];
uint16_t regs16[WordPerThread*4];
uint8_t regs8[WordPerThread*8];
};
regs64[2*g+0] = regs[2*g+0];
regs64[2*g+1] = regs[2*g+1];
int remaining = eltN - ix*EltPer16B;
if (sizeof(T) == 8) {
uint64_t *dst64 = (uint64_t*)(dst+ix*EltPer16B);
for (int i=0; i < 2 && i < remaining; i++)
__builtin_nontemporal_store(regs64[2*g+i], dst64+i);
} else if (sizeof(T) == 4) {
uint32_t *dst32 = (uint32_t*)(dst+ix*EltPer16B);
for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T) && i < remaining; i++)
__builtin_nontemporal_store(regs32[2*g+i], dst32+i);
} else if (sizeof(T) == 2) {
uint16_t *dst16 = (uint16_t*)(dst+ix*EltPer16B);
for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T) && i < remaining; i++)
__builtin_nontemporal_store(regs16[2*g+i], dst16+i);
} else if (sizeof(T) == 1) {
uint8_t *dst8 = (uint8_t*)(dst+ix*EltPer16B);
for (int i=0; i < 2*sizeof(uint64_t)/sizeof(T) && i < remaining; i++)
__builtin_nontemporal_store(regs8[2*g+i], dst8+i);
}
}
}
}
__syncwarp();
// Write rest from shmem to dst. No need to coalesce stores to 16-bytes,
// the hardware keeps up fine.
T *shm = (T*)ncclShmem->ll128warp[warp];
int skip = misalignment == 0 ? eltN & -EltPer16B : 0;
for(int i=skip+wid; i < eltN; i += WARP_SIZE)
dst[i] = shm[i];
}
#define WARP_MASK 0xffffffff
@@ -197,10 +222,11 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
needReload = false;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
load128(ptr+u*WARP_SIZE, vr[u], vr[u+1]);
vr[u] = __builtin_nontemporal_load(ptr+u*WARP_SIZE);
vr[u+1] = __builtin_nontemporal_load(ptr+u*WARP_SIZE+1);
needReload |= flagThread && (vr[u+1] != flag);
}
} while (__any_sync(WARP_MASK, needReload) && checkAbort(spins, 0, 0) == 0);
} while (__any(needReload) && checkAbort(spins, 0, 0) == 0);
}
/************* Finish register load **************/
@@ -238,10 +264,11 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
needReload = false;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
load128(ptr+u*WARP_SIZE, vr[u], vr[u+1]);
vr[u] = __builtin_nontemporal_load(ptr+u*WARP_SIZE);
vr[u+1] = __builtin_nontemporal_load(ptr+u*WARP_SIZE+1);
needReload |= flagThread && (vr[u+1] != flag);
}
} while (__any_sync(WARP_MASK, needReload) && checkAbort(spins, i, 0) == 0);
} while (__any(needReload) && checkAbort(spins, i, 0) == 0);
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
@@ -260,6 +287,9 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
}
}
#if !defined(__gfx1030__)
if (tid == 0) __asm__ __volatile__("buffer_wbinvl1_vol");
#endif
/************************ Send **************************/
if (SEND) {
for (int i=1; i<MaxSend && i<fan.nsend(); i++) {
@@ -267,14 +297,16 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
uint64_t* ptr = sendPtr(i)+ll128Offset;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
__builtin_nontemporal_store(v[u], ptr+u*WARP_SIZE);
__builtin_nontemporal_store(flagThread ? flag : v[u+1], ptr+u*WARP_SIZE+1);
}
}
uint64_t flag = sendFlag(0);
uint64_t* ptr = sendPtr(0)+ll128Offset;
#pragma unroll
for (int u=0; u<ELEMS_PER_THREAD; u+=2) {
store128(ptr+u*WARP_SIZE, v[u], flagThread ? flag : v[u+1]);
__builtin_nontemporal_store(v[u], ptr+u*WARP_SIZE);
__builtin_nontemporal_store(flagThread ? flag : v[u+1], ptr+u*WARP_SIZE+1);
}
}
/********************** End Send ************************/
@@ -361,12 +393,14 @@ class Primitives<T, RedOp, Fan, Direct, ProtoLL128, P2p>:
public:
__device__ Primitives(
const int tid, const int nthreads, int const *recvPeers, int const *sendPeers,
void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0, int connIndex=0
void const *inputBuf, void *outputBuf, uint64_t redOpArg, int group=0
):
redOp(redOpArg),
tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), warp(tid/WARP_SIZE),
flagThread((tid%8)==7), group(group),
flagThread((tid%4)==3), group(group&(uint16_t)0xFFFF),
stepSize(ncclShmem->comm.buffSizes[NCCL_PROTO_LL128]/NCCL_STEPS/sizeof(uint64_t)) {
barriers = &ncclShmem->groups[this->group].barrier;
barrier_next = ncclShmem->groups[this->group].barrier_next;
auto *channel = &ncclShmem->channel;
int nrecv=0, nsend=0;
+20 -24
Parādīt failu
@@ -68,10 +68,13 @@
NCCL_FUNCS3B(func, Sum), /*PreMulSum*/ \
NCCL_FUNCS3B(func, Sum) /*SumPostDiv*/
typedef void(*ncclKern_t)(struct ncclDevComm* comm, struct ncclWorkElem first);
typedef void(*ncclKern_t)(struct ncclDevComm* comm);
// Must be consistent with the ncclFuncSet enum
static ncclKern_t const ncclKerns[1] = {
static ncclKern_t const ncclKerns[4] = {
NCCL_KERN_NAME(SendRecv, RING, SIMPLE, Sum, int8_t),
NCCL_KERN_NAME_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t),
NCCL_KERN_NAME_LL128(SendRecv, RING, SIMPLE, Sum, int8_t),
NCCL_KERN_NAME_LL128_DEBUG(SendRecv, RING, SIMPLE, Sum, int8_t),
};
// Determine the maximum kernel stack size of all CUDA kernels
@@ -89,6 +92,19 @@ error:
return (res != ncclSuccess) ? 0 : max;
}
// Determine kernel stack size from index
size_t ncclKernLocalSize(int i) {
ncclResult_t res = ncclSuccess;
int numNcclKerns = sizeof(ncclKerns)/sizeof(ncclKerns[0]);
hipFuncAttributes attr = {0};
if (i < numNcclKerns)
CUDACHECKGOTO(hipFuncGetAttributes(&attr, (const void*)(ncclKerns[i])), res, error);
error:
return (res != ncclSuccess) ? 0 : attr.localSizeBytes;
}
// Set shared memory carveout for the nccl kernels
ncclResult_t ncclKernSetSharedMemoryCarveout(int carveOut) {
ncclResult_t res = ncclSuccess;
@@ -174,14 +190,6 @@ static ncclResult_t setupLaunch(struct ncclQueueInfo* eqInfo, int usingCudaGraph
}
channel->workFifo[(channel->workFifoTail-1)%NCCL_MAX_OPS].header.isLast = 1;
if (c == 0) {
// As we inline the first coll directly, we can free it immediately.
// Except P2P or aggregation or registration cases
struct ncclWork* work = channel->workFifo+((channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS);
if (work->header.type == ncclWorkTypeColl && eqInfo->elemList->count() == 1)
work->header.type = ncclWorkTypeUnused;
}
if (channel->gdrMemDesc) {
// GDRCOPY support
uint64_t first = (channel->workFifoTail-channel->workCount)%NCCL_MAX_OPS;
@@ -762,14 +770,7 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) {
// Inline the first kernel
if (params->func == NULL) {
params->func = (void *)ncclKerns[0];
if (work->header.type == ncclWorkTypeColl) {
// Copy the first operation to the inline argument. Type may be set later to
// ncclWorkTypeUnused if we have more than one coll element.
memcpy(&comm->args, work->elems, sizeof(struct ncclWorkElem));
comm->args.bid = 0; // Only inline for channel 0
comm->args.header.isLast = 1; // I am so far the last element
}
params->func = (void *)ncclKerns[ncclGetKernelIndex(comm)];
}
// Register and exchange input and output buffers
@@ -781,9 +782,6 @@ static ncclResult_t ncclSetupCollKernel(struct ncclInfo* info) {
NCCLCHECK(ncclRegBuffAndExchange(info, &eqElem->buffRegInfo));
comm->enqueueInfo->nRegBuffs += eqElem->buffRegInfo.nBuffs;
work->header.type = ncclWorkTypeRegColl;
// Disable inline argument because we need kernel to copy the entire ncclWork from workFifo
// because the registered addresses are in ncclWorkElemReg
comm->args.header.type = ncclWorkTypeUnused;
}
return ncclSuccess;
@@ -883,7 +881,6 @@ ncclResult_t ncclSetupAsyncKernels(ncclComm_t comm) {
}
NCCLCHECK(ncclSetupCollKernel(info));
}
comm->args.header.type = ncclWorkTypeUnused; // disable inline argument
}
// Reset counters
comm->asyncOpCount = 0;
@@ -1106,9 +1103,8 @@ ncclResult_t ncclSetupP2pKernel(struct ncclInfo* info) {
// Just for CUDA kernel to know this is a P2P operation
// The CUDA kernel does not use the inlined first work element as fastpath argument
if (params->func == NULL) {
params->func = (void *)ncclKerns[0];
params->func = (void *)ncclKerns[ncclGetKernelIndex(comm)];
//params->func = ncclKerns[eqElem->work.header.funcIndex];
comm->args.header.type = ncclWorkTypeUnused;
}
return ncclSuccess;
}
+2
Parādīt failu
@@ -838,6 +838,8 @@ static void parseOptions(struct ncclTopoSystem* system, const char *options) {
system->pivotA2ANumBiRings = atol(tokens[i*2+1]);
} else if (strcmp(tokens[i*2], "tuning") == 0) {
system->tuning = atol(tokens[i*2+1]);
} else if (strcmp(tokens[i*2], "ll128Enabled") == 0) {
system->ll128Enabled = (bool)atol(tokens[i*2+1]);
}
}
free(str_temp);
+5
Parādīt failu
@@ -165,6 +165,7 @@ struct ncclTopoSystem {
bool pivotA2AEnabled;
int pivotA2ANumBiRings;
bool ll128Enabled;
};
ncclResult_t ncclTopoGetNode(struct ncclTopoSystem* system, struct ncclTopoNode** node, int type, uint64_t id);
@@ -209,4 +210,8 @@ static ncclResult_t ncclTopoRankToIndex(struct ncclTopoSystem* system, int rank,
static float ncclTopoXGMISpeed(int gcn) {
return gcn == 910 ? MI200_XGMI_WIDTH : VEGA_XGMI_WIDTH;
}
#define ncclGetKernelIndex(p_comm) \
(((p_comm)->topo->ll128Enabled ? 1 : 0)*2 + ((p_comm)->hostDevComm.collTraceThread ? 1 : 0))
#endif
+17 -10
Parādīt failu
@@ -191,30 +191,30 @@ static struct tuningModel tuning_model_3 {
static struct tuningModel tuning_model_4 {
.hwLat = {
/* NVLINK */
{ /* Tree (LL/LL128/Simple)*/ { 0.8, 0.0, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 0.0, 3.6 }, /* CollNet (LL/LL128/Simple)*/ { 0.8, 0.0, 2.5 } },
{ /* Tree (LL/LL128/Simple)*/ { 0.8, 1.4, 2.5 }, /* Ring (LL/LL128/Simple)*/ { 0.8, 2.2, 3.6 }, /* CollNet (LL/LL128/Simple)*/ { 0.8, 1.4, 2.5 } },
/* PCI */
{ /* Tree (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* Ring (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 }, /* CollNet (LL/LL128/Simple)*/ { 2.2, 2.2, 5.7 } },
/* NET */
{ /* Tree (LL/LL128/Simple)*/ { 45.8, 0.0, 105.0 }, /* Ring (LL/LL128/Simple)*/ { 19.2, 0.0, 51.0 }, /* CollNet (LL/LL128/Simple)*/ { 45.8, 0.0, 105.0 } },
{ /* Tree (LL/LL128/Simple)*/ { 45.8, 62.5, 105.0 }, /* Ring (LL/LL128/Simple)*/ { 19.2, 44.6, 51.0 }, /* CollNet (LL/LL128/Simple)*/ { 45.8, 62.5, 105.0 } },
},
.bwRatio = {
/* 2 nodes */
{ /* Tree (LL/LL128/Simple)*/ { 0.12, 0.00, 1.41 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 0.00, 1.00 }, /* CollNet (LL/LL128/Simple)*/ { 1.00, 1.00, 1.00 } },
{ /* Tree (LL/LL128/Simple)*/ { 0.12, 0.21, 1.41 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 0.26, 1.00 }, /* CollNet (LL/LL128/Simple)*/ { 1.00, 1.00, 1.00 } },
/* more than 2 nodes */
{ /* Tree (LL/LL128/Simple)*/ { 0.12, 0.00, 1.05 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 0.00, 1.00 }, /* CollNet (LL/LL128/Simple)*/ { 1.00, 1.00, 1.00 } },
{ /* Tree (LL/LL128/Simple)*/ { 0.12, 0.21, 1.05 }, /* Ring (LL/LL128/Simple)*/ { 0.12, 0.26, 1.00 }, /* CollNet (LL/LL128/Simple)*/ { 1.00, 1.00, 1.00 } },
},
.treeCorrectionFactor = {
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.4, 1.0, 0.5, 0.8, 0.4, 0.3, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, },
{ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.3, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9, 0.9, 0.8, 0.8, 0.8, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.1, 0.2, 0.1, 0.1, 0.1, 0.2, 0.3, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.4, 0.1, 0.1, 0.2, 0.3, 0.4, 0.2, 0.3, 0.3, 0.4, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, },
{ 0.1, 0.5, 0.1, 0.4, 0.1, 0.1, 0.2, 0.1, 1.0, 0.3, 0.1, 0.1, 0.1, 1.0, 0.6, 1.0, 1.0, 1.0, 1.0, 0.9, 0.8, 0.7, 0.6, 0.7, 0.6, 0.8, 0.8, },
},
.ringCorrectionFactor = {
{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.6, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, 0.2, 0.2, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.1, },
{ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, },
{ 0.6, 0.4, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 1.0, 0.8, 1.0, 1.0, 1.0, 0.7, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, },
{ 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.4, 0.1, 0.3, 0.1, 0.1, 0.2, 0.2, 0.3, 0.1, 0.2, 0.3, 0.3, 0.2, 0.2, 0.1, 0.2, 0.1, 0.1, 0.1, 0.1, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 0.4, 1.0, 1.0, 1.0, 0.8, 0.6, 0.2, 0.4, 0.6, 0.5, 0.5, 0.4, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, },
{ 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.2, 0.9, 0.9, 0.8, 1.0, 0.4, 0.6, 0.7, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, },
},
};
@@ -362,9 +362,16 @@ ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCom
for (int c=0; c<NCCL_NUM_FUNCTIONS; c++) for (int a=0; a<NCCL_NUM_ALGORITHMS; a++) for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
int pEnable = protoEnable[p];
if (pEnable == 2 && p == NCCL_PROTO_LL128) {
#if defined(__HIP_PLATFORM_HCC__) || defined(__HCC__) || defined(__HIPCC__)
// Enable LL128 by default only on gfx90a with available tuning table
pEnable = (graphs[a]->typeInter <= PATH_PXB) && graphs[a]->typeIntra <= PATH_NVL &&
(comm->topo->nodes[GPU].nodes[0].gpu.gcn == 910 && comm->topo->ll128Enabled) ? 1 : 0;
#else
// Enable LL128 by default only on Volta/Ampere+NVLink. Other cases are not tested and may cause silent data corruption.
pEnable = (graphs[a]->typeInter <= PATH_PXB) && graphs[a]->typeIntra <= PATH_NVL &&
((minCompCap == 70 && maxCompCap == 70) || (minCompCap == 80 && maxCompCap == 80)) ? 1 : 0;
#endif
if (comm->rank == 0 && c == 0 && a == 0) INFO(NCCL_INIT, "Using tuning table %d with LL128 %s", comm->topo->tuning, pEnable ? "enabled" : "disabled");
}
if (pEnable == 0) comm->bandwidths[c][a][p] = 0;
// Only disable algo for Allreduce since others only have one
+13 -1
Parādīt failu
@@ -32,13 +32,25 @@ struct ncclDevRedOpFull {
#define NCCL_KERN_NAME(func, algo, proto, devredop, type) \
ncclKernel_##func##_##algo##_##proto##_##devredop##_##type
#define NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type) \
ncclKernelDebug_##func##_##algo##_##proto##_##devredop##_##type
#define NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type) \
ncclKernelLL128_##func##_##algo##_##proto##_##devredop##_##type
#define NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type) \
ncclKernelLL128Debug_##func##_##algo##_##proto##_##devredop##_##type
#define NCCL_IMPL_NAME(func, algo, proto) \
nccl##func##algo##proto
/* Declare all collective operations */
#define DECL5(func, algo, proto, devredop, type) \
extern __device__ __attribute__((noinline)) void NCCL_FUNC_NAME(func, algo, proto, devredop, type)(); \
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm, struct ncclWorkElem c); \
extern __global__ void NCCL_KERN_NAME(func, algo, proto, devredop, type)(struct ncclDevComm* comm); \
extern __global__ void NCCL_KERN_NAME_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm); \
extern __global__ void NCCL_KERN_NAME_LL128(func, algo, proto, devredop, type)(struct ncclDevComm* comm); \
extern __global__ void NCCL_KERN_NAME_LL128_DEBUG(func, algo, proto, devredop, type)(struct ncclDevComm* comm);
#define CONCAT(a,b) a##b
#define MACRO_IF(cond, t, f) CONCAT(MACRO_IF_, cond)(t, f)
+1 -2
Parādīt failu
@@ -197,8 +197,7 @@ struct ncclComm {
int* intraCudaDevs;
int* intraCGMode; // Whether we can use CUDA9 CGMD or not
int* intraCC; // Only to check all have the same ComputeCap and disable CGMode if not
struct ncclWorkElem args;
void* argsptrs[2];
void* argsptrs[1];
struct ncclProxyState proxyState;
+3 -3
Parādīt failu
@@ -76,18 +76,18 @@ union ncclLLFifoLine {
// Make sure the clean mask will last for at least NCCL_NSTEPS
static_assert(NCCL_LL_CLEAN_MASK % NCCL_STEPS == 0, "Invalid NCCL_LL_CLEAN_MASK value");
#define NCCL_LL128_LINESIZE 128
#define NCCL_LL128_LINESIZE 64
#define NCCL_LL128_LINEELEMS (NCCL_LL128_LINESIZE/sizeof(uint64_t))
#define NCCL_LL128_DATAELEMS (NCCL_LL128_LINEELEMS-1)
#define NCCL_LL128_MAX_NTHREADS 256
#define NCCL_LL128_ELEMS_PER_THREAD 120
#define NCCL_LL128_ELEMS_PER_THREAD 28
// Receiving from up to 3 sources is more compute intensive than sending
// to 3 dests. Use 70% for reduce and 30% for bcast.
#define NCCL_LL128_SPLIT(nt) ((nt*7/(10*32))*32)
#define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 2
#define NCCL_LL128_SHMEM_ELEMS_PER_THREAD 4
#define NCCL_LL128_SHMEM_SIZE (NCCL_LL128_SHMEM_ELEMS_PER_THREAD*NCCL_LL128_MAX_NTHREADS)
#define NCCL_DIRECT_WRITE 0x01
+1
Parādīt failu
@@ -16,6 +16,7 @@
#define NCCL_AGG_CHANNEL_SIZE (1LL << 21) /* 2 MiB, ideal per-channel size to fully utilize bandwidth */
size_t ncclKernMaxLocalSize();
size_t ncclKernLocalSize(int i);
ncclResult_t ncclKernSetSharedMemoryCarveout(int carveOut);
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info);
ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast);
+8 -2
Parādīt failu
@@ -333,6 +333,7 @@ static ncclResult_t commFree(ncclComm_t comm) {
RCCL_PARAM(CliqueIgnoreTopo, "CLIQUE_IGNORE_TOPO", 0);
RCCL_PARAM(P2pNetDisable, "P2P_NET_DISABLE", 0);
RCCL_PARAM(PivotAlltoallEnable, "PIVOT_ALLTOALL_ENABLE", 1);
RCCL_PARAM(LL128ForceEnable, "LL128_FORCE_ENABLE", 0);
NCCL_PARAM(AggChannelSize, "AGG_CHANNEL_SIZE", -2);
NCCL_PARAM(DisableGraphHelper, "GRAPH_HELPER_DISABLE", 0);
NCCL_PARAM(GraphRegister, "GRAPH_REGISTER", 0);
@@ -383,7 +384,6 @@ static ncclResult_t commAlloc(ncclComm_t* comret, int ndev, int rank, int virtua
comm->p2pOpCount = 0;
comm->argsptrs[0] = &comm->devComm;
comm->argsptrs[1] = &comm->args;
#ifdef ENABLE_PROFILING
NCCLCHECK(ncclCudaCalloc(&comm->hostDevComm.devProf, MAXCHANNELS*PROFILE_NUM_LAUNCHES));
#endif
@@ -703,6 +703,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
// init Pivot A2A related fields
comm->topo->pivotA2AEnabled = false;
comm->topo->pivotA2ANumBiRings = 0;
// LL128
comm->topo->ll128Enabled = false;
// Compute paths between GPUs and NICs
NCCLCHECK(ncclTopoComputePaths(comm->topo, comm->peerInfo));
// Remove inaccessible GPUs and unused NICs
@@ -845,6 +847,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
struct ncclGraphInfo collNet;
struct ncclTopoRanks topoRanks;
bool pivotA2AEnabled;
bool ll128Enabled;
} *allGather3Data;
NCCLCHECK(ncclCalloc(&allGather3Data, nranks));
@@ -893,6 +896,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
allGather3Data[rank].collNet.typeInter = collNetGraph.typeInter;
allGather3Data[rank].collNetSupport = comm->collNetSupport;
allGather3Data[rank].pivotA2AEnabled = comm->topo->pivotA2AEnabled && rcclParamPivotAlltoallEnable();
comm->topo->ll128Enabled = comm->topo->ll128Enabled || rcclParamLL128ForceEnable();
allGather3Data[rank].ll128Enabled = comm->topo->ll128Enabled;
comm->nChannels = (comm->topo->nodes[GPU].count != comm->topo->nRanks && comm->topo->nodes[NET].count)
? std::min(treeGraph.nChannels, ringGraph.nChannels) : ringGraph.nChannels;
@@ -979,6 +984,7 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
collNetGraph.typeInter = std::max(allGather3Data[i].collNet.typeInter, collNetGraph.typeInter);
comm->collNetSupport = std::min(allGather3Data[i].collNetSupport, comm->collNetSupport);
comm->topo->pivotA2AEnabled = comm->topo->pivotA2AEnabled && allGather3Data[i].pivotA2AEnabled;
comm->topo->ll128Enabled = comm->topo->ll128Enabled && allGather3Data[i].ll128Enabled;
}
comm->nChannels = treeGraph.nChannels = ringGraph.nChannels =
@@ -1239,7 +1245,7 @@ ncclResult_t ncclCommInitRankSync(ncclComm_t* newcomm, int nranks, ncclUniqueId
NCCLCHECKGOTO(initTransportsRank(*newcomm, &commId), res, cleanup);
NCCLCHECKGOTO(devCommSetup(*newcomm), res, cleanup);
INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx localSize %ld used %ld bytes - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId, maxLocalSizeBytes, allocTracker[(*newcomm)->cudaDev].totalAllocSize);
INFO(NCCL_INIT,"comm %p rank %d nranks %d cudaDev %d busId %lx localSize %ld used %ld bytes - Init COMPLETE", *newcomm, myrank, nranks, (*newcomm)->cudaDev, (*newcomm)->busId, ncclKernLocalSize(ncclGetKernelIndex(*newcomm)), allocTracker[(*newcomm)->cudaDev].totalAllocSize);
return ncclSuccess;
cleanup: