Update MP UT to support arbitrary # of GPUs; multiple bugfixes (#16)

* Fixing temp file creation/deletion for Clique kernel mode.

* Refactoring of MP unit tests; includes bugfixes and general support for any number of GPUs

* GroupCall MP UT properly quits when too many devices are specified

* MP UT will programmatically set NCCL_COMM_ID if not specified; updated install script
이 커밋은 다음에 포함됨:
Stanley Tsang
2021-02-05 17:49:25 -07:00
커밋한 사람 GitHub
부모 6dfdfef98f
커밋 d00b7d17bd
23개의 변경된 파일, 538개의 추가 그리고 716개의 삭제
+13 -49
파일 보기
@@ -10,63 +10,27 @@ namespace CorrectnessTests
{
TEST_P(GatherMultiProcessCorrectnessTest, Correctness)
{
Dataset* dataset = (Dataset*)mmap(NULL, sizeof(Dataset), PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, -1, 0);
dataset->InitializeRootProcess(numDevices, numElements, dataType, inPlace, ncclCollGather);
Barrier::ClearShmFiles(std::atoi(getenv("NCCL_COMM_ID")));
std::vector<int> pids(numDevices);
int pid1 = 0;
int pid2 = 0;
int pid3 = 0;
pid1 = fork();
// From this point on, ignore original process as we cannot have it create a HIP context
if (pid1 == 0)
int gpu = -1;
for (int i = 0; i < numDevices; i++)
{
pid2 = fork();
if (numDevices > 2)
gpu++;
int pid = fork();
if (pid == 0)
{
pid3 = fork();
}
if ((pid2 > 0 && pid3 == 0 && numDevices == 2) || (pid2 > 0 && pid3 > 0 && numDevices > 2))
{
// Process 0
TestGather(0, *dataset);
if (pid3 > 0)
{
waitpid(pid3, NULL, 0);
}
}
else if ((pid2 == 0 && pid3 == 0 && numDevices == 2) || (pid2 == 0 && pid3 > 0 && numDevices > 2))
{
// Process 1
TestGather(1, *dataset);
if (numDevices > 2)
{
waitpid(pid3, NULL, 0);
}
exit(0);
}
else if (pid2 > 0 && pid3 == 0 && numDevices > 2)
{
// Process 2 (available when numDevices > 2)
TestGather(2, *dataset);
exit(0);
}
else if (pid2 == 0 && pid3 == 0 && numDevices == 4)
{
// Process 3 (available when numDevices == 4)
TestGather(3, *dataset);
exit(0);
bool pass;
TestGather(gpu, *dataset, pass);
TerminateChildProcess(pass);
}
else
{
exit(0);
pids[gpu] = pid;
}
waitpid(pid2, NULL, 0);
exit(0);
}
waitpid(pid1, NULL, 0);
munmap(dataset, sizeof(Dataset));
ValidateProcesses(pids);
}
INSTANTIATE_TEST_SUITE_P(GatherMultiProcessCorrectnessSweep,
@@ -88,7 +52,7 @@ namespace CorrectnessTests
// Number of elements
testing::Values(1024, 1048576),
// Number of devices
testing::Values(2,3,4),
testing::Values(2,3,4,8),
// In-place or not
testing::Values(false),
testing::Values("RCCL_ALLTOALL_KERNEL_DISABLE=0", "RCCL_ALLTOALL_KERNEL_DISABLE=1")),