diff --git a/projects/rccl/tools/topo_expl/topo_expl.cpp b/projects/rccl/tools/topo_expl/topo_expl.cpp index 555bb28225..1f87fa0646 100644 --- a/projects/rccl/tools/topo_expl/topo_expl.cpp +++ b/projects/rccl/tools/topo_expl/topo_expl.cpp @@ -209,20 +209,30 @@ int main(int argc,char* argv[]) bootstrapAllGather(&comm[i], allGather1Data); } - struct ncclTopoGraph treeGraph, ringGraph, collNetGraph; - - for (int i = 0; i < nranks; i++) { - node_model = network.GetNode(i); - assert(node_model!=0); - initTransportsRank_1(&comm[i], allGather1Data, allGather3Data, treeGraph, ringGraph, collNetGraph); + struct ncclTopoGraph *treeGraph, *ringGraph, *collNetGraph; + treeGraph = (struct ncclTopoGraph *)malloc(sizeof(struct ncclTopoGraph)*nranks); + ringGraph = (struct ncclTopoGraph *)malloc(sizeof(struct ncclTopoGraph)*nranks); + collNetGraph = (struct ncclTopoGraph *)malloc(sizeof(struct ncclTopoGraph)*nranks); + if (!treeGraph || !ringGraph || !collNetGraph) { + printf("Failed to allocate memory for graphs\n"); + return -1; } for (int i = 0; i < nranks; i++) { node_model = network.GetNode(i); assert(node_model!=0); - initTransportsRank_3(&comm[i], allGather3Data, treeGraph, ringGraph, collNetGraph); + initTransportsRank_1(&comm[i], allGather1Data, allGather3Data, treeGraph[i], ringGraph[i], collNetGraph[i]); } + for (int i = 0; i < nranks; i++) { + node_model = network.GetNode(i); + assert(node_model!=0); + initTransportsRank_3(&comm[i], allGather3Data, treeGraph[i], ringGraph[i], collNetGraph[i]); + } + + free(treeGraph); + free(ringGraph); + free(collNetGraph); free(allGather3Data); free(allGather1Data);