From e7eff47be4160c4e2b2c916db8e714f3dc41bcb2 Mon Sep 17 00:00:00 2001 From: Wenkai Du Date: Fri, 15 May 2020 14:15:40 -0700 Subject: [PATCH] Revert "Tuning the inline and unroll to reduce the scratch usage" This reverts commit d8a06589c99c16376e06fc3ccc799a7eb3e28fec. [ROCm/rccl commit: ca493a6b51c8d73be7a77639dbc708cbf9d3539f] --- projects/rccl/src/collectives/device/all_reduce.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/projects/rccl/src/collectives/device/all_reduce.h b/projects/rccl/src/collectives/device/all_reduce.h index 095367fd42..8d5934b5e3 100644 --- a/projects/rccl/src/collectives/device/all_reduce.h +++ b/projects/rccl/src/collectives/device/all_reduce.h @@ -102,7 +102,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) { #endif } -template +template __attribute__((noinline)) __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { const int tid = threadIdx.x; @@ -128,7 +128,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { struct ncclTree* tree = &channel->treeUp; // Reduce : max number of recv is 3, max number of send is 1 (binary tree + local) ncclPrimitivesRecvData recvData; - ncclPrimitives<1, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount, recvData); + ncclPrimitives prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount, recvData); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Up ssize_t offset = gridOffset + bid*chunkSize; @@ -147,7 +147,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) { struct ncclTree* tree = &channel->treeDn; // Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local) ncclPrimitivesSendData sendData; - ncclPrimitives<1, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount, sendData); + ncclPrimitives prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount, sendData); for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) { // Down ssize_t offset = gridOffset + bid*chunkSize;