diff --git a/projects/rccl/tools/rccl-prim-test/rccl_prim_test.cpp b/projects/rccl/tools/rccl-prim-test/rccl_prim_test.cpp index 18463b07e5..d328dbf5c8 100644 --- a/projects/rccl/tools/rccl-prim-test/rccl_prim_test.cpp +++ b/projects/rccl/tools/rccl-prim-test/rccl_prim_test.cpp @@ -218,10 +218,11 @@ do { \ } \ } while (0) -static void setupPeers(uint32_t *info, bool* is_xgmi, bool* is_2h4p) { +static void setupPeers(uint32_t *info, bool* is_xgmi) { int deviceCnt, dev; - *is_xgmi = *is_2h4p = 0; + // is_xgmi indicates all link are one hop XGMI + *is_xgmi = 1; HIPCHECK(hipGetDeviceCount(&deviceCnt)); HIPCHECK(hipGetDevice(&dev)); //! If gpus are not peer enabled, enable them @@ -239,20 +240,65 @@ static void setupPeers(uint32_t *info, bool* is_xgmi, bool* is_2h4p) { HIPCHECK(hipDeviceEnablePeerAccess(j, 0)); uint32_t linktype; HIPCHECK(hipExtGetLinkTypeAndHopCount(i, j, &linktype, &info[i*deviceCnt+j])); - if (*is_xgmi == 0 && linktype == 4) *is_xgmi = 1; + if (linktype != 4 || info[i*deviceCnt+j] != 1) *is_xgmi = 0; } else info[i*deviceCnt+j] = 0; } } - if (*is_xgmi && deviceCnt == 8) { - uint32_t linktype, hop; - HIPCHECK(hipExtGetLinkTypeAndHopCount(0, 4, &linktype, &hop)); - if (linktype != 4) *is_2h4p = 1; - } HIPCHECK(hipSetDevice(dev)); } +static void parseChordalRing(char **str) { + static const char *ringBase = "0 6 7 4 5 3 2 1|0 5 6 3 7 1 4 2|0 4 6 2 7 5 1 3|0 1 2 3 5 4 7 6|0 2 4 1 7 3 6 5|0 3 1 5 7 2 6 4"; + static char ringRemap[256]; + int id[8], dist[8]; + int i; + + int ngpus; + HIPCHECK(hipGetDeviceCount(&ngpus)); + // single node CR8G only + if (ngpus != 8) + return; + // validate chordal ring and calculate distance + for (i=0; i ngpus-1) { + return; + } + dist[i] = sum; + } + // remap GPU ids + for (i = 0; i= '0' && ringBase[i] <= '9') + ringRemap[i] = id[ringBase[i]-'0']+'0'; + else + ringRemap[i] = ringBase[i]; + } + ringRemap[i] = 0; + *str = ringRemap; + return; +} + static void printRing(int id, int *ring, int deviceCnt) { printf("Ring %d: ", id); for (int i = 0; i < deviceCnt; i++) @@ -370,16 +416,13 @@ int main(int argc,char* argv[]) uint32_t connection_info[MAX_GPU*MAX_GPU]; // Enable peer access - bool is_xgmi, is_2h4p; - setupPeers(connection_info, &is_xgmi, &is_2h4p); - hipDeviceProp_t prop; - HIPCHECK(hipGetDeviceProperties(&prop, 0)); + bool is_xgmi; + char *cr8g = 0; static const char *ring_4p3l = "0 1 2 3|0 1 3 2|0 2 1 3|0 2 3 1|0 3 1 2|0 3 2 1"; - static const char *ring_8p6l = "0 4 5 6 7 3 2 1|0 7 4 3 5 1 6 2|0 6 4 2 5 7 1 3|0 1 2 3 7 6 5 4|0 2 6 1 5 3 4 7|0 3 1 7 5 2 4 6"; - if (prop.gcnArch == 908) { - if (nGpu == 4 && is_xgmi) r = (char *)ring_4p3l; - if (nGpu == 8 && is_xgmi && !is_2h4p) r = (char *)ring_8p6l; - } + setupPeers(connection_info, &is_xgmi); + parseChordalRing(&cr8g); + if (nGpu == 4 && is_xgmi) r = (char *)ring_4p3l; + if (nGpu == 8 && cr8g) r = (char *)cr8g; // clockwise and counter clockwise rings int ring[MAX_WORKGROUPS][MAX_GPU];