Merge pull request #188 from wenkaidu/prim_test
rccl-prim-test: auto-detect rings in 4P and 8P configurations
[ROCm/rccl commit: 3ac98e7d39]
Esse commit está contido em:
@@ -218,10 +218,11 @@ do { \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
static void setupPeers(uint32_t *info, bool* is_xgmi, bool* is_2h4p) {
|
||||
static void setupPeers(uint32_t *info, bool* is_xgmi) {
|
||||
int deviceCnt, dev;
|
||||
|
||||
*is_xgmi = *is_2h4p = 0;
|
||||
// is_xgmi indicates all link are one hop XGMI
|
||||
*is_xgmi = 1;
|
||||
HIPCHECK(hipGetDeviceCount(&deviceCnt));
|
||||
HIPCHECK(hipGetDevice(&dev));
|
||||
//! If gpus are not peer enabled, enable them
|
||||
@@ -239,20 +240,65 @@ static void setupPeers(uint32_t *info, bool* is_xgmi, bool* is_2h4p) {
|
||||
HIPCHECK(hipDeviceEnablePeerAccess(j, 0));
|
||||
uint32_t linktype;
|
||||
HIPCHECK(hipExtGetLinkTypeAndHopCount(i, j, &linktype, &info[i*deviceCnt+j]));
|
||||
if (*is_xgmi == 0 && linktype == 4) *is_xgmi = 1;
|
||||
if (linktype != 4 || info[i*deviceCnt+j] != 1) *is_xgmi = 0;
|
||||
}
|
||||
else
|
||||
info[i*deviceCnt+j] = 0;
|
||||
}
|
||||
}
|
||||
if (*is_xgmi && deviceCnt == 8) {
|
||||
uint32_t linktype, hop;
|
||||
HIPCHECK(hipExtGetLinkTypeAndHopCount(0, 4, &linktype, &hop));
|
||||
if (linktype != 4) *is_2h4p = 1;
|
||||
}
|
||||
HIPCHECK(hipSetDevice(dev));
|
||||
}
|
||||
|
||||
static void parseChordalRing(char **str) {
|
||||
static const char *ringBase = "0 6 7 4 5 3 2 1|0 5 6 3 7 1 4 2|0 4 6 2 7 5 1 3|0 1 2 3 5 4 7 6|0 2 4 1 7 3 6 5|0 3 1 5 7 2 6 4";
|
||||
static char ringRemap[256];
|
||||
int id[8], dist[8];
|
||||
int i;
|
||||
|
||||
int ngpus;
|
||||
HIPCHECK(hipGetDeviceCount(&ngpus));
|
||||
// single node CR8G only
|
||||
if (ngpus != 8)
|
||||
return;
|
||||
// validate chordal ring and calculate distance
|
||||
for (i=0; i<ngpus; i++) {
|
||||
int sum = ngpus*(ngpus-1)/2 - i;
|
||||
int count = 0;
|
||||
for (int n = 0; n<ngpus; n++) {
|
||||
uint32_t linktype, hop;
|
||||
HIPCHECK(hipExtGetLinkTypeAndHopCount(i, n, &linktype, &hop));
|
||||
if (linktype != 4 || hop != 1) continue;
|
||||
sum -= n;
|
||||
count ++;
|
||||
}
|
||||
if(count != ngpus-2 || sum < 0 || sum > ngpus-1) {
|
||||
return;
|
||||
}
|
||||
dist[i] = sum;
|
||||
}
|
||||
// remap GPU ids
|
||||
for (i = 0; i<ngpus; i++) id[i] = i;
|
||||
for (i = 0; i<ngpus; i++) {
|
||||
if (dist[i] == ngpus-1-i) continue;
|
||||
int j, m, n, temp;
|
||||
for (j=i+1; j < ngpus; j++)
|
||||
if(dist[j] == ngpus-1-i) break;
|
||||
m = dist[i]; n = dist[j]; dist[i] = n; dist[j] = m;
|
||||
temp = id[m]; id[m] = id[n]; id[n] = temp; temp =dist[m];
|
||||
dist[m] = dist[n]; dist[n] = temp;
|
||||
}
|
||||
// create chordal ring based on reference and remapped ids
|
||||
for (i = 0; i <strlen(ringBase); i++) {
|
||||
if (ringBase[i] >= '0' && ringBase[i] <= '9')
|
||||
ringRemap[i] = id[ringBase[i]-'0']+'0';
|
||||
else
|
||||
ringRemap[i] = ringBase[i];
|
||||
}
|
||||
ringRemap[i] = 0;
|
||||
*str = ringRemap;
|
||||
return;
|
||||
}
|
||||
|
||||
static void printRing(int id, int *ring, int deviceCnt) {
|
||||
printf("Ring %d: ", id);
|
||||
for (int i = 0; i < deviceCnt; i++)
|
||||
@@ -370,16 +416,13 @@ int main(int argc,char* argv[])
|
||||
|
||||
uint32_t connection_info[MAX_GPU*MAX_GPU];
|
||||
// Enable peer access
|
||||
bool is_xgmi, is_2h4p;
|
||||
setupPeers(connection_info, &is_xgmi, &is_2h4p);
|
||||
hipDeviceProp_t prop;
|
||||
HIPCHECK(hipGetDeviceProperties(&prop, 0));
|
||||
bool is_xgmi;
|
||||
char *cr8g = 0;
|
||||
static const char *ring_4p3l = "0 1 2 3|0 1 3 2|0 2 1 3|0 2 3 1|0 3 1 2|0 3 2 1";
|
||||
static const char *ring_8p6l = "0 4 5 6 7 3 2 1|0 7 4 3 5 1 6 2|0 6 4 2 5 7 1 3|0 1 2 3 7 6 5 4|0 2 6 1 5 3 4 7|0 3 1 7 5 2 4 6";
|
||||
if (prop.gcnArch == 908) {
|
||||
if (nGpu == 4 && is_xgmi) r = (char *)ring_4p3l;
|
||||
if (nGpu == 8 && is_xgmi && !is_2h4p) r = (char *)ring_8p6l;
|
||||
}
|
||||
setupPeers(connection_info, &is_xgmi);
|
||||
parseChordalRing(&cr8g);
|
||||
if (nGpu == 4 && is_xgmi) r = (char *)ring_4p3l;
|
||||
if (nGpu == 8 && cr8g) r = (char *)cr8g;
|
||||
|
||||
// clockwise and counter clockwise rings
|
||||
int ring[MAX_WORKGROUPS][MAX_GPU];
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário