Fix incorrect net counting (#339)

* Fix incorrect net counting

* Add comments
This commit is contained in:
Wenkai Du
2021-04-05 12:21:57 -07:00
committato da GitHub
parent 1d2946ee4b
commit 17491c918e
2 ha cambiato i file con 4 aggiunte e 2 eliminazioni
+2 -1
Vedi File
@@ -1012,7 +1012,8 @@ static ncclResult_t initTransportsRank(struct ncclComm* comm, ncclUniqueId* comm
// count NETs used by ring
int nNets = 0;
int nets[MAXCHANNELS*2];
for (int i = 0; i < ringGraph.nChannels; i++) {
// do not count NETs in case of single node, i.e comm->topo->nodes[GPU].count == comm->topo->nRanks
for (int i = 0; comm->topo->nodes[GPU].count != comm->topo->nRanks && i < ringGraph.nChannels; i++) {
for (int j = 0; j < 2; j++) {
int k;
for (k = 0; k < nNets; k++)
+2 -1
Vedi File
@@ -604,7 +604,8 @@ ncclResult_t initTransportsRank_3(struct ncclComm* comm, struct allGather3Data_t
// count NETs used by ring
int nNets = 0;
int nets[MAXCHANNELS*2];
for (int i = 0; i < ringGraph.nChannels; i++) {
// do not count NETs in case of single node, i.e comm->topo->nodes[GPU].count == comm->topo->nRanks
for (int i = 0; comm->topo->nodes[GPU].count != comm->topo->nRanks && i < ringGraph.nChannels; i++) {
for (int j = 0; j < 2; j++) {
int k;
for (k = 0; k < nNets; k++)