Merge pull request #136 from wenkaidu/tree
Enable tree kernels in build
[ROCm/rccl commit: 062c798c86]
Этот коммит содержится в:
@@ -9,7 +9,7 @@
|
||||
#ifndef NCCL_COLLECTIVES_H_
|
||||
#define NCCL_COLLECTIVES_H_
|
||||
|
||||
#define FUNC_INDEX(coll, redop, dtype, ll, al) ((((coll*ncclNumOps + redop)*ncclNumTypes) + dtype)*2+ll)
|
||||
#define FUNC_INDEX(coll, redop, dtype, ll, al) ((((((coll)*ncclNumOps + (redop))*ncclNumTypes) + (dtype))*2+(al))*2+(ll))
|
||||
|
||||
#define NCCL_COLL_NAME(coll, op, dtype) \
|
||||
coll##_##op##_##dtype
|
||||
@@ -27,7 +27,8 @@
|
||||
DECL_COLL5(coll##LL, op, dtype)
|
||||
|
||||
#define DECL_COLL3(coll, op, dtype) \
|
||||
DECL_COLL4(coll##Ring, op, dtype)
|
||||
DECL_COLL4(coll##Ring, op, dtype) \
|
||||
DECL_COLL4(coll##Tree, op, dtype)
|
||||
|
||||
#define DECL_COLL2(coll, op) \
|
||||
DECL_COLL3(coll, op, i8) \
|
||||
|
||||
@@ -40,7 +40,8 @@ static inline __device__ void exitIfAbortBarrier(int abort) {
|
||||
NCCL_COLL_NAME(coll##LL, op, dtype)
|
||||
|
||||
#define NCCL_FUNC4(coll, op, dtype) \
|
||||
NCCL_FUNC5(coll##Ring, op, dtype)
|
||||
NCCL_FUNC5(coll##Ring, op, dtype), \
|
||||
NCCL_FUNC5(coll##Tree, op, dtype)
|
||||
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(coll, op) \
|
||||
@@ -120,16 +121,20 @@ struct Caller<f, f + 1>{
|
||||
inline
|
||||
__device__
|
||||
void NCCL_CALL_FUNCTIONS(struct ncclColl* const c) noexcept {
|
||||
if (c->funcIndex < 72) {
|
||||
if (c->funcIndex % 2) ncclBroadcastRingLL_copy_i8(&c->args);
|
||||
else ncclBroadcastRing_copy_i8(&c->args);
|
||||
if (c->funcIndex < 144) {
|
||||
if (c->funcIndex % 4 == 0) ncclBroadcastRing_copy_i8(&c->args);
|
||||
else if (c->funcIndex % 4 == 1) ncclBroadcastRingLL_copy_i8(&c->args);
|
||||
else if (c->funcIndex % 4 == 2) ncclBroadcastTree_copy_i8(&c->args);
|
||||
else ncclBroadcastTreeLL_copy_i8(&c->args);
|
||||
}
|
||||
else if (c->funcIndex < 144) Caller<72, 144>::call(c);
|
||||
else if (c->funcIndex < 216) {
|
||||
if (c->funcIndex % 2) ncclAllGatherRingLL_copy_i8(&c->args);
|
||||
else ncclAllGatherRing_copy_i8(&c->args);
|
||||
else if (c->funcIndex < 288) Caller<144, 288>::call(c);
|
||||
else if (c->funcIndex < 432) {
|
||||
if (c->funcIndex % 4 == 0) ncclAllGatherRing_copy_i8(&c->args);
|
||||
else if (c->funcIndex % 4 == 1) ncclAllGatherRingLL_copy_i8(&c->args);
|
||||
else if (c->funcIndex % 4 == 2) ncclAllGatherTree_copy_i8(&c->args);
|
||||
else ncclAllGatherTreeLL_copy_i8(&c->args);
|
||||
}
|
||||
else Caller<216, 360>::call(c);
|
||||
else Caller<432, 720>::call(c);
|
||||
}
|
||||
|
||||
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid, uint32_t* abortCount) {
|
||||
@@ -210,7 +215,8 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
|
||||
IMPL_COLL_KERN_##op(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \
|
||||
|
||||
#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
|
||||
IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0)
|
||||
IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0) \
|
||||
IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 1)
|
||||
|
||||
#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
|
||||
IMPL_COLL3(coll, op, ncclFunc, i8, int8_t, ncclColl, ncclOp, ncclInt8) \
|
||||
|
||||
@@ -17,7 +17,8 @@
|
||||
NCCL_KERN_NAME(coll##LL, op, dtype)
|
||||
|
||||
#define NCCL_FUNC4(coll, op, dtype) \
|
||||
NCCL_FUNC5(coll##Ring, op, dtype)
|
||||
NCCL_FUNC5(coll##Ring, op, dtype), \
|
||||
NCCL_FUNC5(coll##Tree, op, dtype)
|
||||
|
||||
// Must be consistent with ncclDataType_t
|
||||
#define NCCL_FUNCS3A(coll, op) \
|
||||
@@ -55,7 +56,7 @@
|
||||
|
||||
typedef void(*ncclKern_t)(struct ncclColl);
|
||||
// Must be consistent with the ncclFuncSet enum
|
||||
static ncclKern_t const ncclKerns[ncclCollCount*ncclNumOps*ncclNumTypes*2] = {
|
||||
static ncclKern_t const ncclKerns[ncclCollCount*ncclNumOps*ncclNumTypes*2*2] = {
|
||||
NCCL_FUNCS2B(ncclBroadcast),
|
||||
NCCL_FUNCS2A(ncclReduce),
|
||||
NCCL_FUNCS2B(ncclAllGather),
|
||||
|
||||
Ссылка в новой задаче
Block a user