diff --git a/src/collectives/device/common.h b/src/collectives/device/common.h index fd26814b0f..2086c1c057 100644 --- a/src/collectives/device/common.h +++ b/src/collectives/device/common.h @@ -152,7 +152,6 @@ __device__ void NCCL_COLL_NAME(coll, op, dtype)(struct CollectiveArgs* args) { \ coll##Kernel, ctype>(args); \ } -#if NCCL_OP == 0 /* Kernels with the first operation inlined */ #define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \ __launch_bounds__(MAXTHREADS+WARP_SIZE, 1) \ @@ -195,15 +194,20 @@ __global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \ load_coll(c, channel->devCollectives+nextIndex, tid, &abortCount); \ } \ } -#else -#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) -#endif + +#define IMPL_COLL_KERN_sum(coll, op, ncclFunc, dtype, ctype, fIndex) \ + IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) +#define IMPL_COLL_KERN_copy(coll, op, ncclFunc, dtype, ctype, fIndex) \ + IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) +#define IMPL_COLL_KERN_prod(coll, op, ncclFunc, dtype, ctype, fIndex) +#define IMPL_COLL_KERN_min(coll, op, ncclFunc, dtype, ctype, fIndex) +#define IMPL_COLL_KERN_max(coll, op, ncclFunc, dtype, ctype, fIndex) // Only generate inline kernels for LL #define IMPL_COLL4(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, al) \ IMPL_COLL_FUNC(coll, op, ncclFunc, dtype, ctype) \ IMPL_COLL_FUNC(coll##LL, op, ncclFunc, dtype, ctype) \ - IMPL_COLL_KERN(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \ + IMPL_COLL_KERN_##op(coll##LL, op, ncclFunc, dtype, ctype, FUNC_INDEX(ncclColl, ncclOp, ncclType, 1, al)) \ #define IMPL_COLL3(coll, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType) \ IMPL_COLL4(coll##Ring, op, ncclFunc, dtype, ctype, ncclColl, ncclOp, ncclType, 0)