From 2f99c7bbad1e102152ef6427d2f67ece3cd58a17 Mon Sep 17 00:00:00 2001 From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com> Date: Wed, 1 Jul 2020 15:11:02 -0700 Subject: [PATCH] topo_expl: each rank needs to have its own memory for graphs (#225) [ROCm/rccl commit: d3548cc474c3878e31bad2e3ff9258c1f7afe5b6] --- projects/rccl/tools/topo_expl/topo_expl.cpp | 24 +++++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/projects/rccl/tools/topo_expl/topo_expl.cpp b/projects/rccl/tools/topo_expl/topo_expl.cpp index 555bb28225..1f87fa0646 100644 --- a/projects/rccl/tools/topo_expl/topo_expl.cpp +++ b/projects/rccl/tools/topo_expl/topo_expl.cpp @@ -209,20 +209,30 @@ int main(int argc,char* argv[]) bootstrapAllGather(&comm[i], allGather1Data); } - struct ncclTopoGraph treeGraph, ringGraph, collNetGraph; - - for (int i = 0; i < nranks; i++) { - node_model = network.GetNode(i); - assert(node_model!=0); - initTransportsRank_1(&comm[i], allGather1Data, allGather3Data, treeGraph, ringGraph, collNetGraph); + struct ncclTopoGraph *treeGraph, *ringGraph, *collNetGraph; + treeGraph = (struct ncclTopoGraph *)malloc(sizeof(struct ncclTopoGraph)*nranks); + ringGraph = (struct ncclTopoGraph *)malloc(sizeof(struct ncclTopoGraph)*nranks); + collNetGraph = (struct ncclTopoGraph *)malloc(sizeof(struct ncclTopoGraph)*nranks); + if (!treeGraph || !ringGraph || !collNetGraph) { + printf("Failed to allocate memory for graphs\n"); + return -1; } for (int i = 0; i < nranks; i++) { node_model = network.GetNode(i); assert(node_model!=0); - initTransportsRank_3(&comm[i], allGather3Data, treeGraph, ringGraph, collNetGraph); + initTransportsRank_1(&comm[i], allGather1Data, allGather3Data, treeGraph[i], ringGraph[i], collNetGraph[i]); } + for (int i = 0; i < nranks; i++) { + node_model = network.GetNode(i); + assert(node_model!=0); + initTransportsRank_3(&comm[i], allGather3Data, treeGraph[i], ringGraph[i], collNetGraph[i]); + } + + free(treeGraph); + free(ringGraph); + free(collNetGraph); free(allGather3Data); free(allGather1Data);