Add tilled version of puts and gets at wavefront level to the functional test suite

* Implemented tiled version of put*_wave and get*_wave functions
* Maintain single kernel that supports both tiled and untiled versions
* Disable IPC in the default RO build script


[ROCm/rocshmem commit: b6d31ac7ef]
Этот коммит содержится в:
avinashkethineedi
2024-09-07 16:06:36 -07:00
родитель 3d26792831
Коммит 9532e084fc
9 изменённых файлов: 78 добавлений и 82 удалений
+2 -1
Просмотреть файл
@@ -17,8 +17,9 @@ cmake \
-DDEBUG=OFF \
-DPROFILE=OFF \
-DUSE_GPU_IB=OFF \
-DUSE_RO=ON \
-DUSE_DC=OFF \
-DUSE_IPC=ON \
-DUSE_IPC=OFF \
-DUSE_THREADS=ON \
-DUSE_WF_COAL=OFF \
-DUSE_COHERENT_HEAP=ON \
+16 -4
Просмотреть файл
@@ -68,16 +68,16 @@ case $2 in
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_n2_w1_z1_1MB.log
check wg_putnbi_n2_w1_z1_1MB
echo "wg_get_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 28 -ti 1 > $3/wg_get_tiled_n2_w1_z1_1MB.log
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 28 > $3/wg_get_tiled_n2_w1_z1_1MB.log
check wg_get_tiled_n2_w1_z1_1MB
echo "wg_getnbi_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 29 -ti 1 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 29 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log
check wg_getnbi_tiled_n2_w1_z1_1MB
echo "wg_put_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 30 -ti 1 > $3/wg_put_tiled_n2_w1_z1_1MB.log
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 30 > $3/wg_put_tiled_n2_w1_z1_1MB.log
check wg_put_tiled_n2_w1_z1_1MB
echo "wg_putnbi_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 -ti 1 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log
check wg_putnbi_tiled_n2_w1_z1_1MB
echo "wave_get_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 32 > $3/wave_get_n2_w1_z1_1MB.log
@@ -91,6 +91,18 @@ case $2 in
echo "wave_putnbi_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 35 > $3/wave_putnbi_n2_w1_z1_1MB.log
check wave_putnbi_n2_w1_z1_1MB
echo "wave_get_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 32 > $3/wave_get_tiled_n2_w1_z1_1MB.log
check wave_get_tiled_n2_w1_z1_1MB
echo "wave_getnbi_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 33 > $3/wave_getnbi_tiled_n2_w1_z1_1MB.log
check wave_getnbi_tiled_n2_w1_z1_1MB
echo "wave_put_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 34 > $3/wave_put_tiled_n2_w1_z1_1MB.log
check wave_put_tiled_n2_w1_z1_1MB
echo "wave_putnbi_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 35 > $3/wave_putnbi_tiled_n2_w1_z1_1MB.log
check wave_putnbi_tiled_n2_w1_z1_1MB
echo "amofadd_n2_w1_z1"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 1 -a 6 > $3/amofadd_n2_w1_z1.log
check amofadd_n2_w1_z1
+6 -54
Просмотреть файл
@@ -37,47 +37,6 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer,
roc_shmem_wg_init();
roc_shmem_wg_ctx_create(ctx_type, &ctx);
uint64_t start;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) start = roc_shmem_timer();
switch (type) {
case WGGetTestType:
roc_shmemx_ctx_getmem_wg(ctx, r_buf, s_buf, size, 1);
break;
case WGGetNBITestType:
roc_shmemx_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
break;
case WGPutTestType:
roc_shmemx_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1);
break;
case WGPutNBITestType:
roc_shmemx_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
break;
default:
break;
}
}
roc_shmem_ctx_quiet(ctx);
if (hipThreadIdx_x == 0) {
timer[hipBlockIdx_x] = roc_shmem_timer() - start;
}
roc_shmem_wg_ctx_destroy(&ctx);
roc_shmem_wg_finalize();
}
__global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer,
char *s_buf, char *r_buf, int size,
TestType type,
ShmemContextType ctx_type) {
__shared__ roc_shmem_ctx_t ctx;
roc_shmem_wg_init();
roc_shmem_wg_ctx_create(ctx_type, &ctx);
uint64_t start;
uint64_t idx = size * get_flat_grid_id();
s_buf += idx;
@@ -119,8 +78,8 @@ __global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer,
*****************************************************************************/
ExtendedPrimitiveTester::ExtendedPrimitiveTester(TesterArguments args)
: Tester(args) {
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs);
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs);
}
ExtendedPrimitiveTester::~ExtendedPrimitiveTester() {
@@ -129,24 +88,17 @@ ExtendedPrimitiveTester::~ExtendedPrimitiveTester() {
}
void ExtendedPrimitiveTester::resetBuffers(uint64_t size) {
memset(s_buf, '0', args.max_msg_size * args.wg_size);
memset(r_buf, '1', args.max_msg_size * args.wg_size);
memset(s_buf, '0', size * args.num_wgs);
memset(r_buf, '1', size * args.num_wgs);
}
void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
int loop, uint64_t size) {
size_t shared_bytes = 0;
if (args.tiled){
hipLaunchKernelGGL(ExtendedPrimitiveTestTiled, gridSize, blockSize, shared_bytes,
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
_shmem_context);
}
else {
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
_shmem_context);
}
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop * gridSize.x;
@@ -158,7 +110,7 @@ void ExtendedPrimitiveTester::verifyResults(uint64_t size) {
: 1;
if (args.myid == check_id) {
for (int i = 0; i < size; i++) {
for (int i = 0; i < size * args.num_wgs; i++) {
if (r_buf[i] != '0') {
fprintf(stderr, "Data validation error at idx %d\n", i);
fprintf(stderr, "Got %c, Expected %c \n", r_buf[i], '0');
+33 -8
Просмотреть файл
@@ -57,6 +57,9 @@
Tester::Tester(TesterArguments args) : args(args) {
_type = (TestType)args.algorithm;
_shmem_context = args.shmem_context;
CHECK_HIP(hipGetDevice(&device_id));
CHECK_HIP(hipGetDeviceProperties(&deviceProps, device_id));
num_warps = args.wg_size / deviceProps.warpSize;
CHECK_HIP(hipStreamCreate(&stream));
CHECK_HIP(hipEventCreate(&start_event));
CHECK_HIP(hipEventCreate(&stop_event));
@@ -470,28 +473,32 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
return testers;
case WGGetTestType:
if (rank == 0) {
if (args.tiled) std::cout << "Tiled Blocking WG level Gets***" << std::endl;
if (args.num_wgs > 1)
std::cout << "Tiled Blocking WG level Gets***" << std::endl;
else std::cout << "Blocking WG level Gets***" << std::endl;
}
testers.push_back(new ExtendedPrimitiveTester(args));
return testers;
case WGGetNBITestType:
if (rank == 0) {
if (args.tiled) std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl;
if (args.num_wgs > 1)
std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl;
else std::cout << "Non-Blocking WG level Gets***" << std::endl;
}
testers.push_back(new ExtendedPrimitiveTester(args));
return testers;
case WGPutTestType:
if (rank == 0) {
if (args.tiled) std::cout << "Tiled Blocking WG level Puts***" << std::endl;
if (args.num_wgs > 1)
std::cout << "Tiled Blocking WG level Puts***" << std::endl;
else std::cout << "Blocking WG level Puts***" << std::endl;
}
testers.push_back(new ExtendedPrimitiveTester(args));
return testers;
case WGPutNBITestType:
if (rank == 0) {
if(args.tiled) std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl;
if (args.num_wgs > 1)
std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl;
else std::cout << "Non-Blocking WG level Puts***" << std::endl;
}
testers.push_back(new ExtendedPrimitiveTester(args));
@@ -502,19 +509,35 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
testers.push_back(new PrimitiveMRTester(args));
return testers;
case WAVEGetTestType:
if (rank == 0) std::cout << "WAVE Blocking Gets***" << std::endl;
if (rank == 0) {
if (args.num_wgs > 1 || args.wg_size / 64 > 1)
std::cout << "Tiled Blocking WAVE level Gets***" << std::endl;
else std::cout << "Blocking WAVE level Gets***" << std::endl;
}
testers.push_back(new WaveLevelPrimitiveTester(args));
return testers;
case WAVEGetNBITestType:
if (rank == 0) std::cout << "WAVE Non-Blocking Gets***" << std::endl;
if (rank == 0) {
if (args.num_wgs > 1 || args.wg_size / 64 > 1)
std::cout << "Tiled Non-Blocking WAVE level Gets***" << std::endl;
else std::cout << "Non-Blocking WAVE level Gets***" << std::endl;
}
testers.push_back(new WaveLevelPrimitiveTester(args));
return testers;
case WAVEPutTestType:
if (rank == 0) std::cout << "WAVE Blocking Puts***" << std::endl;
if (rank == 0) {
if (args.num_wgs > 1 || args.wg_size / 64 > 1)
std::cout << "Tiled Blocking WAVE level Puts***" << std::endl;
else std::cout << "Blocking WAVE level Puts***" << std::endl;
}
testers.push_back(new WaveLevelPrimitiveTester(args));
return testers;
case WAVEPutNBITestType:
if (rank == 0) std::cout << "WAVE Non-Blocking Puts***" << std::endl;
if (rank == 0) {
if (args.num_wgs > 1 || args.wg_size / 64 > 1)
std::cout << "Tiled Non-Blocking WAVE level Puts***" << std::endl;
else std::cout << "Non-Blocking WAVE level Puts***" << std::endl;
}
testers.push_back(new WaveLevelPrimitiveTester(args));
return testers;
default:
@@ -684,6 +707,8 @@ uint64_t Tester::gpuCyclesToMicroseconds(uint64_t cycles) {
uint64_t Tester::timerAvgInMicroseconds() {
uint64_t sum = 0;
//TODO: Modify the calcuation for the Tiled version of puts and gets at
// wavefront level (bpotter/avinash)
for (int i = 0; i < args.num_wgs; i++) {
sum += gpuCyclesToMicroseconds(timer[i]);
}
+3
Просмотреть файл
@@ -117,7 +117,9 @@ class Tester {
int num_msgs = 0;
int num_timed_msgs = 0;
int num_warps = 0;
int bw_factor = 1;
int device_id = 0;
TesterArguments args;
@@ -125,6 +127,7 @@ class Tester {
ShmemContextType _shmem_context = 8; // SHMEM_CTX_WP_PRIVATE
hipStream_t stream;
hipDeviceProp_t deviceProps;
uint64_t *timer = nullptr;
-4
Просмотреть файл
@@ -60,9 +60,6 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
} else if (arg == "-x") {
i++;
shmem_context = atoi(argv[i]);
} else if (arg == "-ti") {
i++;
tiled = (atoi(argv[i]) == 1);
} else {
show_usage(argv[0]);
exit(-1);
@@ -122,7 +119,6 @@ void TesterArguments::show_usage(std::string executable_name) {
std::cout << "\t-o <Operation type for the random_access test>\n";
std::cout << "\t-ta <Number of Thread Accessing the communication>\n";
std::cout << "\t-x <shmem context>\n";
std::cout << "\t-ti <Tiled version>\n";
}
void TesterArguments::get_rocshmem_arguments() {
-1
Просмотреть файл
@@ -58,7 +58,6 @@ class TesterArguments {
unsigned coal_coef = 64;
unsigned op_type = 0;
unsigned shmem_context = rocshmem::ROC_SHMEM_CTX_WG_PRIVATE;
bool tiled = false;
/**
* Arguments obtained from rocshmem
+17 -10
Просмотреть файл
@@ -27,17 +27,22 @@
using namespace rocshmem;
/******************************************************************************
* DEVICE TEST KERNEL
* DEVICE TEST KERNELS
*****************************************************************************/
__global__ void WaveLevelPrimitiveTest(int loop, int skip, uint64_t *timer,
char *s_buf, char *r_buf, int size,
TestType type,
ShmemContextType ctx_type) {
TestType type, ShmemContextType ctx_type,
int wf_size) {
__shared__ roc_shmem_ctx_t ctx;
roc_shmem_wg_init();
roc_shmem_wg_ctx_create(ctx_type, &ctx);
uint64_t start;
int wf_id = get_flat_block_id() / wf_size;
int offset = size * get_flat_grid_id() * (get_flat_block_size() / wf_size);
int idx = wf_id * size + offset;
s_buf += idx;
r_buf += idx;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) start = roc_shmem_timer();
@@ -75,8 +80,10 @@ __global__ void WaveLevelPrimitiveTest(int loop, int skip, uint64_t *timer,
*****************************************************************************/
WaveLevelPrimitiveTester::WaveLevelPrimitiveTester(TesterArguments args)
: Tester(args) {
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs
* num_warps);
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs
* num_warps);
}
WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() {
@@ -85,8 +92,8 @@ WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() {
}
void WaveLevelPrimitiveTester::resetBuffers(uint64_t size) {
memset(s_buf, '0', args.max_msg_size * args.wg_size);
memset(r_buf, '1', args.max_msg_size * args.wg_size);
memset(s_buf, '0', size * args.num_wgs * num_warps);
memset(r_buf, '1', size * args.num_wgs * num_warps);
}
void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
@@ -95,10 +102,10 @@ void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
hipLaunchKernelGGL(WaveLevelPrimitiveTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
_shmem_context);
_shmem_context, deviceProps.warpSize);
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop;
num_timed_msgs = loop * gridSize.x;
}
void WaveLevelPrimitiveTester::verifyResults(uint64_t size) {
@@ -107,7 +114,7 @@ void WaveLevelPrimitiveTester::verifyResults(uint64_t size) {
: 1;
if (args.myid == check_id) {
for (int i = 0; i < size; i++) {
for (int i = 0; i < size * args.num_wgs * num_warps; i++) {
if (r_buf[i] != '0') {
fprintf(stderr, "Data validation error at idx %d\n", i);
fprintf(stderr, "Got %c, Expected %c \n", r_buf[i], '0');
+1
Просмотреть файл
@@ -24,6 +24,7 @@
#define _WAVE_LEVEL_PRIMITIVE_TEST_HPP_
#include "tester.hpp"
#include "../src/util.hpp"
/******************************************************************************
* HOST TESTER CLASS