Merge pull request #132 from wenkaidu/reduce_kernels

Only generate kernels for sum and copy
This commit is contained in:
Wenkai Du
2019-09-26 16:14:45 -07:00
committato da GitHub
+9 -5
Vedi File
@@ -152,7 +152,6 @@ __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \
coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(args); \
}
#if NCCL_OP == 0
/* Kernels with the first operation inlined */
#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \
__launch_bounds__(MAXTHREADS+WARP_SIZE, 1) \
@@ -195,15 +194,20 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
load_coll(c, channel->devCollectives+nextIndex, tid, &abortCount); \
} \
}
#else
#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex)
#endif
#define IMPL_COLL_KERN_sum(coll, op, ncclFunc, dtype, ctype, fIndex) \
IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex)
#define IMPL_COLL_KERN_copy(coll, op, ncclFunc, dtype, ctype, fIndex) \
IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex)
#define IMPL_COLL_KERN_prod(coll, op, ncclFunc, dtype, ctype, fIndex)
#define IMPL_COLL_KERN_min(coll, op, ncclFunc, dtype, ctype, fIndex)
#define IMPL_COLL_KERN_max(coll, op, ncclFunc, dtype, ctype, fIndex)
// Only generate inline kernels for LL
#define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, al) \
IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \
IMPL_COLL_FUNC(coll##LL, op, ncclFunc, dtype, ctype) \
IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \
IMPL_COLL_KERN_##op(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \
#define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \
IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0)