Merge pull request #210 from wenkaidu/unroll

Revert "Tuning the inline and unroll to reduce the scratch usage"
This commit is contained in:
Wenkai Du
2020-05-15 15:35:27 -07:00
committed by GitHub
parents c245f1507e ca493a6b51
commit af703877cf
+3 -3
View file
@@ -102,7 +102,7 @@ __device__ void ncclAllReduceRingKernel(struct CollectiveArgs* args) {
#endif
}
template<int UNUSED, class FUNC, typename T>
template<int UNROLL, class FUNC, typename T>
__attribute__((noinline))
__device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
@@ -128,7 +128,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
struct ncclTree* tree = &channel->treeUp;
// Reduce : max number of recv is 3, max number of send is 1 (binary tree + local)
ncclPrimitivesRecvData<T, NCCL_MAX_TREE_ARITY> recvData;
ncclPrimitives<1, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount, recvData);
ncclPrimitives<UNROLL, 1, 1, T, NCCL_MAX_TREE_ARITY, 1, FUNC> prims(tid, args->nThreads, tree->down, &tree->up, NULL, stepSize, channel, comm, args->opCount, recvData);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Up
ssize_t offset = gridOffset + bid*chunkSize;
@@ -147,7 +147,7 @@ __device__ void ncclAllReduceTreeKernel(struct CollectiveArgs* args) {
struct ncclTree* tree = &channel->treeDn;
// Broadcast : max number of recv is 1, max number of send is 3 (binary tree + local)
ncclPrimitivesSendData<T, NCCL_MAX_TREE_ARITY> sendData;
ncclPrimitives<1, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount, sendData);
ncclPrimitives<UNROLL, 1, 1, T, 1, NCCL_MAX_TREE_ARITY, FUNC> prims(tid, args->nThreads, &tree->up, tree->down, NULL, stepSize, channel, comm, args->opCount, sendData);
for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
// Down
ssize_t offset = gridOffset + bid*chunkSize;