Merge pull request #136 from wenkaidu/tree

Enable tree kernels in build

[ROCm/rccl commit: 062c798c86]
Этот коммит содержится в:
Wenkai Du
2019-10-09 10:58:52 -07:00
коммит произвёл GitHub
родитель c4ed3d2e08 f86ee41415
Коммит fbcdfd8348
3 изменённых файлов: 22 добавлений и 14 удалений
+3 -2
Просмотреть файл
@@ -9,7 +9,7 @@
#ifndef NCCL_COLLECTIVES_H_
#define NCCL_COLLECTIVES_H_
#define FUNC_INDEX(coll, redop, dtype, ll, al) ((((coll*ncclNumOps + redop)*ncclNumTypes) + dtype)*2+ll)
#define FUNC_INDEX(coll, redop, dtype, ll, al) ((((((coll)*ncclNumOps + (redop))*ncclNumTypes) + (dtype))*2+(al))*2+(ll))
#define NCCL_COLL_NAME(coll, op, dtype) \
coll##_##op##_##dtype
@@ -27,7 +27,8 @@
DECL_COLL5(coll##LL, op, dtype)
#define DECL_COLL3(coll, op, dtype) \
DECL_COLL4(coll##Ring, op, dtype)
DECL_COLL4(coll##Ring, op, dtype) \
DECL_COLL4(coll##Tree, op, dtype)
#define DECL_COLL2(coll, op) \
DECL_COLL3(coll, op, i8) \
+16 -10
Просмотреть файл
@@ -40,7 +40,8 @@ static inline __device__ void exitIfAbortBarrier(int abort) {
NCCL_COLL_NAME(coll##LL, op, dtype)
#define NCCL_FUNC4(coll, op, dtype) \
NCCL_FUNC5(coll##Ring, op, dtype)
NCCL_FUNC5(coll##Ring, op, dtype), \
NCCL_FUNC5(coll##Tree, op, dtype)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(coll, op) \
@@ -120,16 +121,20 @@ struct Caller<f, f + 1>{
inline
__device__
void NCCL_CALL_FUNCTIONS(struct ncclColl* const c) noexcept {
if (c->funcIndex < 72) {
if (c->funcIndex % 2) ncclBroadcastRingLL_copy_i8(&c->args);
else ncclBroadcastRing_copy_i8(&c->args);
if (c->funcIndex < 144) {
if (c->funcIndex % 4 == 0) ncclBroadcastRing_copy_i8(&c->args);
else if (c->funcIndex % 4 == 1) ncclBroadcastRingLL_copy_i8(&c->args);
else if (c->funcIndex % 4 == 2) ncclBroadcastTree_copy_i8(&c->args);
else ncclBroadcastTreeLL_copy_i8(&c->args);
}
else if (c->funcIndex < 144) Caller<72, 144>::call(c);
else if (c->funcIndex < 216) {
if (c->funcIndex % 2) ncclAllGatherRingLL_copy_i8(&c->args);
else ncclAllGatherRing_copy_i8(&c->args);
else if (c->funcIndex < 288) Caller<144, 288>::call(c);
else if (c->funcIndex < 432) {
if (c->funcIndex % 4 == 0) ncclAllGatherRing_copy_i8(&c->args);
else if (c->funcIndex % 4 == 1) ncclAllGatherRingLL_copy_i8(&c->args);
else if (c->funcIndex % 4 == 2) ncclAllGatherTree_copy_i8(&c->args);
else ncclAllGatherTreeLL_copy_i8(&c->args);
}
else Caller<216, 360>::call(c);
else Caller<432, 720>::call(c);
}
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid, uint32_t* abortCount) {
@@ -210,7 +215,8 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
IMPL_COLL_KERN_##op(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \
#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0)
IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0) \
IMPL_COLL4(coll##Tree, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 1)
#define IMPL_COLL2(coll, op, ncclFunc, ncclColl, ncclOp) \
IMPL_COLL3(coll, op, ncclFunc, i8, int8_t, ncclColl, ncclOp, ncclInt8) \
+3 -2
Просмотреть файл
@@ -17,7 +17,8 @@
NCCL_KERN_NAME(coll##LL, op, dtype)
#define NCCL_FUNC4(coll, op, dtype) \
NCCL_FUNC5(coll##Ring, op, dtype)
NCCL_FUNC5(coll##Ring, op, dtype), \
NCCL_FUNC5(coll##Tree, op, dtype)
// Must be consistent with ncclDataType_t
#define NCCL_FUNCS3A(coll, op) \
@@ -55,7 +56,7 @@
typedef void(*ncclKern_t)(struct ncclColl);
// Must be consistent with the ncclFuncSet enum
static ncclKern_t const ncclKerns[ncclCollCount*ncclNumOps*ncclNumTypes*2] = {
static ncclKern_t const ncclKerns[ncclCollCount*ncclNumOps*ncclNumTypes*2*2] = {
NCCL_FUNCS2B(ncclBroadcast),
NCCL_FUNCS2A(ncclReduce),
NCCL_FUNCS2B(ncclAllGather),