Merge pull request #1211 from saurabhAMD/channel

Enable unit tests (UT) to run with more than 64 channels

[ROCm/rccl commit: 959545dce2]
Этот коммит содержится в:
saurabhAMD
2024-06-13 14:38:38 -05:00
коммит произвёл GitHub
родитель f7fb3392fb 44064a612c
Коммит 09c4d50e50
4 изменённых файлов: 93 добавлений и 0 удалений
+24
Просмотреть файл
@@ -102,6 +102,30 @@ namespace RcclUnitTesting
testBed.Finalize();
}
// Sweeps AllReduce over several NCCL_MIN_NCHANNELS settings (56, 84, 112),
// i.e. channel counts both below and above 64.
// The sweep only runs when the environment reports a gfx94 architecture
// (testBed.ev.isGfx94); otherwise the test body does nothing.
TEST(AllReduce, Channels)
{
  TestBed testBed;
  if (!testBed.ev.isGfx94) return;

  // Configuration
  std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAllReduce};
  std::vector<ncclDataType_t> const dataTypes       = {ncclBfloat16, ncclHalf};
  std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
  std::vector<int>            const roots           = {0};
  std::vector<int>            const numElements     = {64 * 1024 * 1024, 1024};
  std::vector<bool>           const inPlaceList     = {false};
  std::vector<bool>           const managedMemList  = {false};
  std::vector<bool>           const useHipGraphList = {false, true};

  // String literals are immutable, so they must be held as pointers-to-const
  // (binding them to `char *` is ill-formed since C++11). setenv() takes
  // `const char *`, so no cast is needed.
  std::vector<char const *>   const channelList     = {"56", "84", "112"};

  for (char const *channel : channelList) {
    // Overwrite (third argument = 1) any existing value for this iteration.
    setenv("NCCL_MIN_NCHANNELS", channel, 1);
    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    // NOTE(review): Finalize() is called on the same TestBed after every
    // sweep - confirm TestBed supports another RunSimpleSweep() afterwards.
    testBed.Finalize();
    // NOTE(review): this removes the variable rather than restoring a value
    // that may have been set before the test started.
    unsetenv("NCCL_MIN_NCHANNELS");
  }
}
TEST(AllReduce, ManagedMemGraph)
{
TestBed testBed;
+24
Просмотреть файл
@@ -85,4 +85,28 @@ namespace RcclUnitTesting
inPlaceList, managedMemList, useHipGraphList);
testBed.Finalize();
}
// Sweeps AllToAll over several NCCL_MIN_NCHANNELS settings (56, 84, 112),
// i.e. channel counts both below and above 64.
// The sweep only runs when the environment reports a gfx94 architecture
// (testBed.ev.isGfx94); otherwise the test body does nothing.
TEST(AllToAll, Channels)
{
  TestBed testBed;
  if (!testBed.ev.isGfx94) return;

  // Configuration
  std::vector<ncclFunc_t>     const funcTypes       = {ncclCollAllToAll};
  std::vector<ncclDataType_t> const dataTypes       = {ncclBfloat16, ncclHalf};
  std::vector<ncclRedOp_t>    const redOps          = {ncclSum};
  std::vector<int>            const roots           = {0};
  std::vector<int>            const numElements     = {64 * 1024 * 1024, 1024};
  std::vector<bool>           const inPlaceList     = {false};
  std::vector<bool>           const managedMemList  = {false};
  std::vector<bool>           const useHipGraphList = {false, true};

  // String literals are immutable, so they must be held as pointers-to-const
  // (binding them to `char *` is ill-formed since C++11). setenv() takes
  // `const char *`, so no cast is needed.
  std::vector<char const *>   const channelList     = {"56", "84", "112"};

  for (char const *channel : channelList) {
    // Overwrite (third argument = 1) any existing value for this iteration.
    setenv("NCCL_MIN_NCHANNELS", channel, 1);
    testBed.RunSimpleSweep(funcTypes, dataTypes, redOps, roots, numElements,
                           inPlaceList, managedMemList, useHipGraphList);
    // NOTE(review): Finalize() is called on the same TestBed after every
    // sweep - confirm TestBed supports another RunSimpleSweep() afterwards.
    testBed.Finalize();
    // NOTE(review): this removes the variable rather than restoring a value
    // that may have been set before the test started.
    unsetenv("NCCL_MIN_NCHANNELS");
  }
}
}
+44
Просмотреть файл
@@ -15,6 +15,48 @@ namespace RcclUnitTesting
int const UT_SINGLE_PROCESS = (1<<0);
int const UT_MULTI_PROCESS = (1<<1);
int getArchInfo(bool *isRightArch)
{
// Prepare parent->child pipe
int pipefd[2];
if (pipe(pipefd) == -1) {
ERROR("Unable to create parent->child pipe for getting number of devices\n");
return TEST_FAIL;
}
pid_t pid = fork();
if (0 == pid) {
bool isGfx94 = false;
int dev;
hipGetDeviceCount(&dev);
for (int deviceId = 0; deviceId < dev; deviceId++) {
char gcn[256];
hipDeviceProp_t devProp;
hipGetDeviceProperties(&devProp, deviceId);
char *gcnArchNameToken = strtok(devProp.gcnArchName, ":");
strcpy(gcn, gcnArchNameToken);
if(std::strncmp("gfx94", gcn, 5) == 0) {
isGfx94 = true;
} else {
isGfx94 = false;
break;
}
}
if (write(pipefd[1], &isGfx94, sizeof(isGfx94)) != sizeof(isGfx94)) return TEST_FAIL;
close(pipefd[0]);
close(pipefd[1]);
exit(EXIT_SUCCESS);
}
else {
int status;
if (read(pipefd[0], isRightArch, sizeof(*isRightArch)) != sizeof(*isRightArch)) return TEST_FAIL;
waitpid(pid, &status, 0);
assert(!status);
close(pipefd[0]);
close(pipefd[1]);
}
return TEST_SUCCESS;
}
int getDeviceCount(int *devices)
{
// Prepare parent->child pipe
@@ -52,6 +94,8 @@ namespace RcclUnitTesting
// NOTE: Cannot use HIP call prior to launching unless it is inside another child process
numDetectedGpus = 0;
getDeviceCount(&numDetectedGpus);
isGfx94 = false;
getArchInfo(&isGfx94);
showNames = GetEnvVar("UT_SHOW_NAMES" , 1);
minGpus = GetEnvVar("UT_MIN_GPUS" , 2);
+1
Просмотреть файл
@@ -29,6 +29,7 @@ namespace RcclUnitTesting
bool showTiming; // Show timing per case at end [UT_SHOW_TIMING]
bool useInteractive; // Run in interactive mode [UT_INTERACTIVE]
int timeoutUs; // Set timeout for child in microseconds [UT_TIMEOUT_US]
bool isGfx94; // Detects if architecture is gfx94
// Constructor that parses and collects environment variables
EnvVars();