diff --git a/projects/rccl/src/collectives/collectives.h b/projects/rccl/src/collectives/collectives.h index c56d90888e..63fcfd2017 100644 --- a/projects/rccl/src/collectives/collectives.h +++ b/projects/rccl/src/collectives/collectives.h @@ -9,7 +9,7 @@ #ifndef NCCL_COLLECTIVES_H_ #define NCCL_COLLECTIVES_H_ -#define FUNC_INDEX(coll, redop, dtype, ll, al) ((((coll*ncclNumOps + redop)*ncclNumTypes) + dtype)*2+ll) +#define FUNC_INDEX(coll, redop, dtype, ll, al) ((((((coll)*ncclNumOps + (redop))*ncclNumTypes) + (dtype))*2+(al))*2+(ll)) #define NCCL_COLL_NAME(coll, op, dtype) \ coll##_##op##_##dtype @@ -27,7 +27,8 @@ DECL_COLL5(coll##LL, op, dtype) #define DECL_COLL3(coll, op, dtype) \ - DECL_COLL4(coll##Ring, op, dtype) + DECL_COLL4(coll##Ring, op, dtype) \ + DECL_COLL4(coll##Tree, op, dtype) #define DECL_COLL2(coll, op) \ DECL_COLL3(coll, op, i8) \ diff --git a/projects/rccl/src/collectives/device/common.h b/projects/rccl/src/collectives/device/common.h index 2086c1c057..863198180e 100644 --- a/projects/rccl/src/collectives/device/common.h +++ b/projects/rccl/src/collectives/device/common.h @@ -40,7 +40,8 @@ static inline __device__ void exitIfAbortBarrier(int abort) { NCCL_COLL_NAME(coll##LL, op, dtype) #define NCCL_FUNC4(coll, op, dtype) \ - NCCL_FUNC5(coll##Ring, op, dtype) + NCCL_FUNC5(coll##Ring, op, dtype), \ + NCCL_FUNC5(coll##Tree, op, dtype) // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(coll, op) \ @@ -120,16 +121,20 @@ struct Caller{ inline __device__ void NCCL_CALL_FUNCTIONS(struct ncclColl* const c) noexcept { - if (c->funcIndex < 72) { - if (c->funcIndex % 2) ncclBroadcastRingLL_copy_i8(&c->args); - else ncclBroadcastRing_copy_i8(&c->args); + if (c->funcIndex < 144) { + if (c->funcIndex % 4 == 0) ncclBroadcastRing_copy_i8(&c->args); + else if (c->funcIndex % 4 == 1) ncclBroadcastRingLL_copy_i8(&c->args); + else if (c->funcIndex % 4 == 2) ncclBroadcastTree_copy_i8(&c->args); + else ncclBroadcastTreeLL_copy_i8(&c->args); } - else if (c->funcIndex < 144) Caller<72, 144>::call(c); - else if (c->funcIndex < 216) { - if (c->funcIndex % 2) ncclAllGatherRingLL_copy_i8(&c->args); - else ncclAllGatherRing_copy_i8(&c->args); + else if (c->funcIndex < 288) Caller<144, 288>::call(c); + else if (c->funcIndex < 432) { + if (c->funcIndex % 4 == 0) ncclAllGatherRing_copy_i8(&c->args); + else if (c->funcIndex % 4 == 1) ncclAllGatherRingLL_copy_i8(&c->args); + else if (c->funcIndex % 4 == 2) ncclAllGatherTree_copy_i8(&c->args); + else ncclAllGatherTreeLL_copy_i8(&c->args); } - else Caller<216, 360>::call(c); + else Caller<432, 720>::call(c); } static __device__ void load_parallel(void* dst, void* src, size_t size, int tid, uint32_t* abortCount) { @@ -210,7 +215,8 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ IMPL_COLL_KERN_##op(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \ #define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \ - IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0) + IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0) \ + IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 1) #define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \ IMPL_COLL3(coll, op, ncclFunc, i8, int8_t, ncclColl, ncclOp, ncclInt8) \ diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc index 0c7b897ec4..a6eb484d4e 100644 --- a/projects/rccl/src/enqueue.cc +++ b/projects/rccl/src/enqueue.cc @@ -17,7 +17,8 @@ NCCL_KERN_NAME(coll##LL, op, dtype) #define NCCL_FUNC4(coll, op, dtype) \ - NCCL_FUNC5(coll##Ring, op, dtype) + NCCL_FUNC5(coll##Ring, op, dtype), \ + NCCL_FUNC5(coll##Tree, op, dtype) // Must be consistent with ncclDataType_t #define NCCL_FUNCS3A(coll, op) \ @@ -55,7 +56,7 @@ typedef void(*ncclKern_t)(struct ncclColl); // Must be consistent with the ncclFuncSet enum -static ncclKern_t const ncclKerns[ncclCollCount*ncclNumOps*ncclNumTypes*2] = { +static ncclKern_t const ncclKerns[ncclCollCount*ncclNumOps*ncclNumTypes*2*2] = { NCCL_FUNCS2B(ncclBroadcast), NCCL_FUNCS2A(ncclReduce), NCCL_FUNCS2B(ncclAllGather),