Add tilled version of puts and gets at the workgroup level to the functional test suite

[ROCm/rocshmem commit: d226922733]
This commit is contained in:
avinashkethineedi
2024-09-04 12:55:10 -07:00
parent 31c43bd58d
commit 3d26792831
6 changed files with 89 additions and 9 deletions
@@ -67,6 +67,18 @@ case $2 in
echo "wg_putnbi_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_n2_w1_z1_1MB.log
check wg_putnbi_n2_w1_z1_1MB
echo "wg_get_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 28 -ti 1 > $3/wg_get_tiled_n2_w1_z1_1MB.log
check wg_get_tiled_n2_w1_z1_1MB
echo "wg_getnbi_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 29 -ti 1 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log
check wg_getnbi_tiled_n2_w1_z1_1MB
echo "wg_put_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 30 -ti 1 > $3/wg_put_tiled_n2_w1_z1_1MB.log
check wg_put_tiled_n2_w1_z1_1MB
echo "wg_putnbi_tiled_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 -ti 1 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log
check wg_putnbi_tiled_n2_w1_z1_1MB
echo "wave_get_n2_w1_z1_1MB"
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 32 > $3/wave_get_n2_w1_z1_1MB.log
check wave_get_n2_w1_z1_1MB
@@ -27,7 +27,7 @@
using namespace rocshmem;
/******************************************************************************
* DEVICE TEST KERNEL
* DEVICE TEST KERNELS
*****************************************************************************/
__global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer,
char *s_buf, char *r_buf, int size,
@@ -70,6 +70,50 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer,
roc_shmem_wg_finalize();
}
__global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer,
char *s_buf, char *r_buf, int size,
TestType type,
ShmemContextType ctx_type) {
__shared__ roc_shmem_ctx_t ctx;
roc_shmem_wg_init();
roc_shmem_wg_ctx_create(ctx_type, &ctx);
uint64_t start;
uint64_t idx = size * get_flat_grid_id();
s_buf += idx;
r_buf += idx;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) start = roc_shmem_timer();
switch (type) {
case WGGetTestType:
roc_shmemx_ctx_getmem_wg(ctx, r_buf, s_buf, size, 1);
break;
case WGGetNBITestType:
roc_shmemx_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
break;
case WGPutTestType:
roc_shmemx_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1);
break;
case WGPutNBITestType:
roc_shmemx_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
break;
default:
break;
}
}
roc_shmem_ctx_quiet(ctx);
if (hipThreadIdx_x == 0) {
timer[hipBlockIdx_x] = roc_shmem_timer() - start;
}
roc_shmem_wg_ctx_destroy(&ctx);
roc_shmem_wg_finalize();
}
/******************************************************************************
* HOST TESTER CLASS METHODS
*****************************************************************************/
@@ -93,17 +137,23 @@ void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
int loop, uint64_t size) {
size_t shared_bytes = 0;
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
if (args.tiled){
hipLaunchKernelGGL(ExtendedPrimitiveTestTiled, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
_shmem_context);
}
else {
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
_shmem_context);
}
num_msgs = (loop + args.skip) * gridSize.x;
num_timed_msgs = loop;
num_timed_msgs = loop * gridSize.x;
}
void ExtendedPrimitiveTester::verifyResults(uint64_t size) {
int check_id = (_type == WGGetTestType || _type == WGGetNBITestType ||
_type == WAVEGetTestType)
int check_id = (_type == WGGetTestType || _type == WGGetNBITestType)
? 0
: 1;
@@ -24,6 +24,7 @@
#define _EXTENDED_PRIMITIVES_HPP_
#include "tester.hpp"
#include "../src/util.hpp"
/******************************************************************************
* HOST TESTER CLASS
@@ -469,19 +469,31 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
testers.push_back(new ShmemPtrTester(args));
return testers;
case WGGetTestType:
if (rank == 0) std::cout << "Blocking WG level Gets***" << std::endl;
if (rank == 0) {
if (args.tiled) std::cout << "Tiled Blocking WG level Gets***" << std::endl;
else std::cout << "Blocking WG level Gets***" << std::endl;
}
testers.push_back(new ExtendedPrimitiveTester(args));
return testers;
case WGGetNBITestType:
if (rank == 0) std::cout << "Non-Blocking WG level Gets***" << std::endl;
if (rank == 0) {
if (args.tiled) std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl;
else std::cout << "Non-Blocking WG level Gets***" << std::endl;
}
testers.push_back(new ExtendedPrimitiveTester(args));
return testers;
case WGPutTestType:
if (rank == 0) std::cout << "Blocking WG level Puts***" << std::endl;
if (rank == 0) {
if (args.tiled) std::cout << "Tiled Blocking WG level Puts***" << std::endl;
else std::cout << "Blocking WG level Puts***" << std::endl;
}
testers.push_back(new ExtendedPrimitiveTester(args));
return testers;
case WGPutNBITestType:
if (rank == 0) std::cout << "Non-Blocking WG level Puts***" << std::endl;
if (rank == 0) {
if(args.tiled) std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl;
else std::cout << "Non-Blocking WG level Puts***" << std::endl;
}
testers.push_back(new ExtendedPrimitiveTester(args));
return testers;
case PutNBIMRTestType:
@@ -60,6 +60,9 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
} else if (arg == "-x") {
i++;
shmem_context = atoi(argv[i]);
} else if (arg == "-ti") {
i++;
tiled = (atoi(argv[i]) == 1);
} else {
show_usage(argv[0]);
exit(-1);
@@ -119,6 +122,7 @@ void TesterArguments::show_usage(std::string executable_name) {
std::cout << "\t-o <Operation type for the random_access test>\n";
std::cout << "\t-ta <Number of Thread Accessing the communication>\n";
std::cout << "\t-x <shmem context>\n";
std::cout << "\t-ti <Tiled version>\n";
}
void TesterArguments::get_rocshmem_arguments() {
@@ -58,6 +58,7 @@ class TesterArguments {
unsigned coal_coef = 64;
unsigned op_type = 0;
unsigned shmem_context = rocshmem::ROC_SHMEM_CTX_WG_PRIVATE;
bool tiled = false;
/**
* Arguments obtained from rocshmem