From f2736a4fb3f5dce926520f02a9c607e1c72c6943 Mon Sep 17 00:00:00 2001
From: Edgar Gabriel
Date: Mon, 3 Oct 2022 15:12:57 +0000
Subject: [PATCH] introduce a hw topology aware bintree for hayabusa architecture.

[ROCm/rccl commit: e645b02cd84e0215698b69f55130ec6a601df900]
---
Review notes (placed after the three-dash marker, so they are not part of
the commit message):

* enqueue.cc: when pivotA2ANumBiRings == 3, the cutoff on
  ncclTypeSize(datatype)*count above which work->pad_0 is set to 1
  (otherwise 2) moves from 65536 to 131072.
* connect.cc: adds the static table hayabusa_tree_matrix[2][16][4], filled
  on first use by hayabusa_tree_matrix_init(). Per the in-code comment each
  row is {child0, child1, child2, parent} for one rank; tree 0 is rooted at
  rank 15 and tree 1 at rank 6 (their parent entry is -1). The child2 entry
  is -1 for every rank in both trees.
* connect.cc: ncclBinaryTreeHayabusaPostset() copies the row selected by
  comm->localRank into channel->binTree.down[0..2] and channel->binTree.up
  for every channel, alternating between the two trees with
  c % NUM_HAYABUSA_TREES.
* init.cc: the new postset is used when comm->virtualId == -1; otherwise
  the existing ncclBinaryTreePostset() is still called.
* NOTE(review): comm->localRank indexes the 16-entry dimension of the table
  and the added code performs no range check - confirm callers guarantee
  0 <= localRank < 16.
* NOTE(review): hayabusa_tree_matrix_is_init is a plain static bool and the
  added code shows no locking around it - confirm the first call cannot
  race between threads.
* NOTE(review): the treeGraph parameter of ncclBinaryTreeHayabusaPostset()
  is not referenced in its body.
* NOTE(review): this copy of the patch was rebuilt from a whitespace-collapsed
  text in which the for-loop header of ncclBinaryTreeHayabusaPostset() was
  damaged; the loop header, indentation and blank context lines are
  reconstructions and should be checked against the upstream commit.

 projects/rccl/src/enqueue.cc       |   2 +-
 projects/rccl/src/graph/connect.cc | 197 +++++++++++++++++++++++++++++
 projects/rccl/src/include/graph.h  |   1 +
 projects/rccl/src/init.cc          |   6 +-
 4 files changed, 204 insertions(+), 2 deletions(-)

diff --git a/projects/rccl/src/enqueue.cc b/projects/rccl/src/enqueue.cc
index ddaba20d9d..481e163d48 100644
--- a/projects/rccl/src/enqueue.cc
+++ b/projects/rccl/src/enqueue.cc
@@ -1224,7 +1224,7 @@ comp_next:
   NCCLCHECK(getPatternInfo(info));
   NCCLCHECK(getLoopInfo(info));
   if (info->comm->topo->pivotA2ANumBiRings == 3 ) {
-    if (ncclTypeSize(info->datatype)*info->count > 65536) {
+    if (ncclTypeSize(info->datatype)*info->count > 131072) {
       work->pad_0 = 1;
     } else {
       work->pad_0 = 2;
diff --git a/projects/rccl/src/graph/connect.cc b/projects/rccl/src/graph/connect.cc
index 7359b70ecb..d60539b11b 100644
--- a/projects/rccl/src/graph/connect.cc
+++ b/projects/rccl/src/graph/connect.cc
@@ -163,6 +163,203 @@ ncclResult_t ncclBinaryTreePostset(struct ncclComm* comm,
 
   return ncclSuccess;
 }
+#define NUM_HAYABUSA_TREES 2
+static bool hayabusa_tree_matrix_is_init=false;
+static int hayabusa_tree_matrix[NUM_HAYABUSA_TREES][16][4];
+
+static void hayabusa_tree_matrix_init()
+{
+  if (hayabusa_tree_matrix_is_init)
+    return;
+
+  // index = rank of proc, child0, child1, child2, parent
+  // channel 0: root is 15
+  hayabusa_tree_matrix[0][0][0] = 1;
+  hayabusa_tree_matrix[0][0][1] = -1;
+  hayabusa_tree_matrix[0][0][2] = -1;
+  hayabusa_tree_matrix[0][0][3] = 4;
+
+  hayabusa_tree_matrix[0][1][0] = -1;
+  hayabusa_tree_matrix[0][1][1] = -1;
+  hayabusa_tree_matrix[0][1][2] = -1;
+  hayabusa_tree_matrix[0][1][3] = 0;
+
+  hayabusa_tree_matrix[0][2][0] = 3;
+  hayabusa_tree_matrix[0][2][1] = -1;
+  hayabusa_tree_matrix[0][2][2] = -1;
+  hayabusa_tree_matrix[0][2][3] = 6;
+
+  hayabusa_tree_matrix[0][3][0] = -1;
+  hayabusa_tree_matrix[0][3][1] = -1;
+  hayabusa_tree_matrix[0][3][2] = -1;
+  hayabusa_tree_matrix[0][3][3] = 2;
+
+  hayabusa_tree_matrix[0][4][0] = 0;
+  hayabusa_tree_matrix[0][4][1] = -1;
+  hayabusa_tree_matrix[0][4][2] = -1;
+  hayabusa_tree_matrix[0][4][3] = 5;
+
+  hayabusa_tree_matrix[0][5][0] = 4;
+  hayabusa_tree_matrix[0][5][1] = -1;
+  hayabusa_tree_matrix[0][5][2] = -1;
+  hayabusa_tree_matrix[0][5][3] = 14;
+
+  hayabusa_tree_matrix[0][6][0] = 2;
+  hayabusa_tree_matrix[0][6][1] = 7;
+  hayabusa_tree_matrix[0][6][2] = -1;
+  hayabusa_tree_matrix[0][6][3] = 14;
+
+  hayabusa_tree_matrix[0][7][0] = -1;
+  hayabusa_tree_matrix[0][7][1] = -1;
+  hayabusa_tree_matrix[0][7][2] = -1;
+  hayabusa_tree_matrix[0][7][3] = 6;
+
+  hayabusa_tree_matrix[0][8][0] = -1;
+  hayabusa_tree_matrix[0][8][1] = -1;
+  hayabusa_tree_matrix[0][8][2] = -1;
+  hayabusa_tree_matrix[0][8][3] = 9;
+
+  hayabusa_tree_matrix[0][9][0] = 13;
+  hayabusa_tree_matrix[0][9][1] = 8;
+  hayabusa_tree_matrix[0][9][2] = -1;
+  hayabusa_tree_matrix[0][9][3] = 11;
+
+  hayabusa_tree_matrix[0][10][0] = -1;
+  hayabusa_tree_matrix[0][10][1] = -1;
+  hayabusa_tree_matrix[0][10][2] = -1;
+  hayabusa_tree_matrix[0][10][3] = 11;
+
+  hayabusa_tree_matrix[0][11][0] = 9;
+  hayabusa_tree_matrix[0][11][1] = 10;
+  hayabusa_tree_matrix[0][11][2] = -1;
+  hayabusa_tree_matrix[0][11][3] = 15;
+
+  hayabusa_tree_matrix[0][12][0] = -1;
+  hayabusa_tree_matrix[0][12][1] = -1;
+  hayabusa_tree_matrix[0][12][2] = -1;
+  hayabusa_tree_matrix[0][12][3] = 13;
+
+  hayabusa_tree_matrix[0][13][0] = 12;
+  hayabusa_tree_matrix[0][13][1] = -1;
+  hayabusa_tree_matrix[0][13][2] = -1;
+  hayabusa_tree_matrix[0][13][3] = 9;
+
+  hayabusa_tree_matrix[0][14][0] = 5;
+  hayabusa_tree_matrix[0][14][1] = 6;
+  hayabusa_tree_matrix[0][14][2] = -1;
+  hayabusa_tree_matrix[0][14][3] = 15;
+
+  hayabusa_tree_matrix[0][15][0] = 14;
+  hayabusa_tree_matrix[0][15][1] = 11;
+  hayabusa_tree_matrix[0][15][2] = -1;
+  hayabusa_tree_matrix[0][15][3] = -1;
+
+  //Channel 1: root is 6
+  hayabusa_tree_matrix[1][0][0] = -1;
+  hayabusa_tree_matrix[1][0][1] = -1;
+  hayabusa_tree_matrix[1][0][2] = -1;
+  hayabusa_tree_matrix[1][0][3] = 1;
+
+  hayabusa_tree_matrix[1][1][0] = 5;
+  hayabusa_tree_matrix[1][1][1] = 0;
+  hayabusa_tree_matrix[1][1][2] = -1;
+  hayabusa_tree_matrix[1][1][3] = 3;
+
+  hayabusa_tree_matrix[1][2][0] = -1;
+  hayabusa_tree_matrix[1][2][1] = -1;
+  hayabusa_tree_matrix[1][2][2] = -1;
+  hayabusa_tree_matrix[1][2][3] = 3;
+
+  hayabusa_tree_matrix[1][3][0] = 1;
+  hayabusa_tree_matrix[1][3][1] = 2;
+  hayabusa_tree_matrix[1][3][2] = -1;
+  hayabusa_tree_matrix[1][3][3] = 7;
+
+  hayabusa_tree_matrix[1][4][0] = -1;
+  hayabusa_tree_matrix[1][4][1] = -1;
+  hayabusa_tree_matrix[1][4][2] = -1;
+  hayabusa_tree_matrix[1][4][3] = 5;
+
+  hayabusa_tree_matrix[1][5][0] = 4;
+  hayabusa_tree_matrix[1][5][1] = -1;
+  hayabusa_tree_matrix[1][5][2] = -1;
+  hayabusa_tree_matrix[1][5][3] = 1;
+
+  hayabusa_tree_matrix[1][6][0] = 7;
+  hayabusa_tree_matrix[1][6][1] = 13;
+  hayabusa_tree_matrix[1][6][2] = -1;
+  hayabusa_tree_matrix[1][6][3] = -1;
+
+  hayabusa_tree_matrix[1][7][0] = 3;
+  hayabusa_tree_matrix[1][7][1] = 15;
+  hayabusa_tree_matrix[1][7][2] = -1;
+  hayabusa_tree_matrix[1][7][3] = 6;
+
+  hayabusa_tree_matrix[1][8][0] = 9;
+  hayabusa_tree_matrix[1][8][1] = -1;
+  hayabusa_tree_matrix[1][8][2] = -1;
+  hayabusa_tree_matrix[1][8][3] = 12;
+
+  hayabusa_tree_matrix[1][9][0] = -1;
+  hayabusa_tree_matrix[1][9][1] = -1;
+  hayabusa_tree_matrix[1][9][2] = -1;
+  hayabusa_tree_matrix[1][9][3] = 8;
+
+  hayabusa_tree_matrix[1][10][0] = -1;
+  hayabusa_tree_matrix[1][10][1] = -1;
+  hayabusa_tree_matrix[1][10][2] = -1;
+  hayabusa_tree_matrix[1][10][3] = 11;
+
+  hayabusa_tree_matrix[1][11][0] = 10;
+  hayabusa_tree_matrix[1][11][1] = -1;
+  hayabusa_tree_matrix[1][11][2] = -1;
+  hayabusa_tree_matrix[1][11][3] = 15;
+
+  hayabusa_tree_matrix[1][12][0] = 8;
+  hayabusa_tree_matrix[1][12][1] = -1;
+  hayabusa_tree_matrix[1][12][2] = -1;
+  hayabusa_tree_matrix[1][12][3] = 13;
+
+  hayabusa_tree_matrix[1][13][0] = 12;
+  hayabusa_tree_matrix[1][13][1] = -1;
+  hayabusa_tree_matrix[1][13][2] = -1;
+  hayabusa_tree_matrix[1][13][3] = 6;
+
+  hayabusa_tree_matrix[1][14][0] = -1;
+  hayabusa_tree_matrix[1][14][1] = -1;
+  hayabusa_tree_matrix[1][14][2] = -1;
+  hayabusa_tree_matrix[1][14][3] = 15;
+
+  hayabusa_tree_matrix[1][15][0] = 11;
+  hayabusa_tree_matrix[1][15][1] = 14;
+  hayabusa_tree_matrix[1][15][2] = -1;
+  hayabusa_tree_matrix[1][15][3] = 7;
+
+  hayabusa_tree_matrix_is_init = true;
+}
+
+static void set_channel_info(int c, int rank, struct ncclChannel *channel)
+{
+  channel->binTree.down[0] = hayabusa_tree_matrix[c%NUM_HAYABUSA_TREES][rank][0];
+  channel->binTree.down[1] = hayabusa_tree_matrix[c%NUM_HAYABUSA_TREES][rank][1];
+  channel->binTree.down[2] = hayabusa_tree_matrix[c%NUM_HAYABUSA_TREES][rank][2];
+  channel->binTree.up = hayabusa_tree_matrix[c%NUM_HAYABUSA_TREES][rank][3];
+}
+
+ncclResult_t ncclBinaryTreeHayabusaPostset(struct ncclComm* comm,
+    struct ncclTopoGraph* treeGraph) {
+  int nChannels = comm->nChannels;
+
+  hayabusa_tree_matrix_init();
+
+  for (int c=0; c<nChannels; c++) {
+    struct ncclChannel* channel = comm->channels+c;
+
+    set_channel_info(c, comm->localRank, channel);
+  }
+
+  return ncclSuccess;
+}
 
 ncclResult_t ncclTreeBasePostset(struct ncclComm* comm,
     struct ncclTopoGraph* treeGraph) {
diff --git a/projects/rccl/src/include/graph.h b/projects/rccl/src/include/graph.h
index 0de51f8406..aee2ad16b5 100644
--- a/projects/rccl/src/include/graph.h
+++ b/projects/rccl/src/include/graph.h
@@ -116,6 +116,7 @@ ncclResult_t ncclTopoPostset(struct ncclComm* comm, int* firstRanks, int* treePa
 
 ncclResult_t ncclTreeBasePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph);
 ncclResult_t ncclBinaryTreePostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph);
+ncclResult_t ncclBinaryTreeHayabusaPostset(struct ncclComm* comm, struct ncclTopoGraph* treeGraph);
 ncclResult_t ncclTopoTuneModel(struct ncclComm* comm, int minCompCap, int maxCompCap, struct ncclTopoGraph* treeGraph, struct ncclTopoGraph* ringGraph, struct ncclTopoGraph* collNetGraph);
 
 #include "info.h"
diff --git a/projects/rccl/src/init.cc b/projects/rccl/src/init.cc
index fee43665ce..a6796dca9b 100644
--- a/projects/rccl/src/init.cc
+++ b/projects/rccl/src/init.cc
@@ -1042,7 +1042,11 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
 
   if (comm->topo->pivotA2ANumBiRings == 3) {
     NCCLCHECK(ncclTreeBasePostset(comm, &treeGraph));
-    NCCLCHECK(ncclBinaryTreePostset(comm, &treeGraph));
+    if (comm->virtualId == -1) {
+      NCCLCHECK(ncclBinaryTreeHayabusaPostset(comm, &treeGraph));
+    } else {
+      NCCLCHECK(ncclBinaryTreePostset(comm, &treeGraph));
+    }
   }
 
   free(allTopoRanks);