From d8a06589c99c16376e06fc3ccc799a7eb3e28fec Mon Sep 17 00:00:00 2001 From: Changpeng Fang Date: Tue, 8 Oct 2019 09:24:49 -0700 Subject: [PATCH] Tuning the inline and unroll to reduce the scratch usage Summary: 1. remove the noinline attribute for AllReduceThreeKernel; 2. change AUTPUNROLL for tree functions to 1 or 2; Combining 1 and 2 will reduce the scratch usage from 1256 to 952 [ROCm/rccl commit: eec319038e88920693484e997880add000e0c7cd] --- projects/rccl/src/collectives/device/all_reduce.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/projects/rccl/src/collectives/device/all_reduce.h b/projects/rccl/src/collectives/device/all_reduce.h index f319b4333e..9810990cc8 100644 --- a/projects/rccl/src/collectives/device/all_reduce.h +++ b/projects/rccl/src/collectives/device/all_reduce.h @@ -102,8 +102,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) { #endif } -template -__attribute__((noinline)) +template __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; const int nthreads = blockDim.x; @@ -122,7 +121,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { do { // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) - ncclPrimitives prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount); + ncclPrimitives<1, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, nthreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Up ssize_t offset = gridOffset + bid*chunkSize; @@ -139,7 +138,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { do { // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) - ncclPrimitives prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount); + ncclPrimitives<1, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, nthreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Down ssize_t offset = gridOffset + bid*chunkSize;