Add tiled versions of puts and gets at the workgroup level to the functional test suite
[ROCm/rocshmem commit: d226922733]
This commit is contained in:
@@ -67,6 +67,18 @@ case $2 in
|
||||
echo "wg_putnbi_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_n2_w1_z1_1MB.log
|
||||
check wg_putnbi_n2_w1_z1_1MB
|
||||
echo "wg_get_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 28 -ti 1 > $3/wg_get_tiled_n2_w1_z1_1MB.log
|
||||
check wg_get_tiled_n2_w1_z1_1MB
|
||||
echo "wg_getnbi_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 29 -ti 1 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log
|
||||
check wg_getnbi_tiled_n2_w1_z1_1MB
|
||||
echo "wg_put_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 30 -ti 1 > $3/wg_put_tiled_n2_w1_z1_1MB.log
|
||||
check wg_put_tiled_n2_w1_z1_1MB
|
||||
echo "wg_putnbi_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 -ti 1 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log
|
||||
check wg_putnbi_tiled_n2_w1_z1_1MB
|
||||
echo "wave_get_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 32 > $3/wave_get_n2_w1_z1_1MB.log
|
||||
check wave_get_n2_w1_z1_1MB
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
* DEVICE TEST KERNELS
|
||||
*****************************************************************************/
|
||||
__global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer,
|
||||
char *s_buf, char *r_buf, int size,
|
||||
@@ -70,6 +70,50 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer,
|
||||
roc_shmem_wg_finalize();
|
||||
}
|
||||
|
||||
/**
 * Tiled variant of the workgroup-level put/get test kernel.
 *
 * s_buf and r_buf are advanced by size * get_flat_grid_id() bytes before any
 * transfer is issued, so every transfer in this launch instance works on its
 * own `size`-byte tile of the two buffers.
 *
 * @param loop      Number of timed iterations.
 * @param skip      Number of leading warm-up iterations; the timer is only
 *                  sampled once iteration index `skip` is reached.
 * @param timer     Output array; thread 0 of each block stores the elapsed
 *                  timer ticks at index hipBlockIdx_x.
 * @param s_buf     Base of the buffer passed as the source argument.
 * @param r_buf     Base of the buffer passed as the destination argument.
 * @param size      Bytes per transfer, and the stride between tiles.
 * @param type      Selects which workgroup-level primitive is exercised;
 *                  any other value makes the loop body a no-op.
 * @param ctx_type  Context type handed to roc_shmem_wg_ctx_create().
 */
__global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer,
                                           char *s_buf, char *r_buf, int size,
                                           TestType type,
                                           ShmemContextType ctx_type) {
  __shared__ roc_shmem_ctx_t ctx;
  roc_shmem_wg_init();
  roc_shmem_wg_ctx_create(ctx_type, &ctx);

  // Zero-initialised so the subtraction below never reads an indeterminate
  // value when loop == 0 (the loop then ends before i reaches skip).
  uint64_t start = 0;

  // Widen before multiplying so the byte offset is computed in 64 bits
  // regardless of the integer type get_flat_grid_id() returns.
  const uint64_t idx = static_cast<uint64_t>(size) * get_flat_grid_id();
  s_buf += idx;
  r_buf += idx;

  // NOTE(review): the trailing argument 1 in the calls below is presumably
  // the target PE -- confirm against the roc_shmem API.
  for (int i = 0; i < loop + skip; i++) {
    // Warm-up iterations are excluded from the measurement.
    if (i == skip) start = roc_shmem_timer();

    switch (type) {
      case WGGetTestType:
        roc_shmemx_ctx_getmem_wg(ctx, r_buf, s_buf, size, 1);
        break;
      case WGGetNBITestType:
        roc_shmemx_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
        break;
      case WGPutTestType:
        roc_shmemx_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1);
        break;
      case WGPutNBITestType:
        roc_shmemx_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
        break;
      default:
        break;
    }
  }

  // Quiet the context before the end timestamp is taken.
  roc_shmem_ctx_quiet(ctx);

  if (hipThreadIdx_x == 0) {
    timer[hipBlockIdx_x] = roc_shmem_timer() - start;
  }

  roc_shmem_wg_ctx_destroy(&ctx);
  roc_shmem_wg_finalize();
}
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS METHODS
|
||||
*****************************************************************************/
|
||||
@@ -93,17 +137,23 @@ void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
int loop, uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
if (args.tiled){
|
||||
hipLaunchKernelGGL(ExtendedPrimitiveTestTiled, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
|
||||
_shmem_context);
|
||||
}
|
||||
else {
|
||||
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
|
||||
_shmem_context);
|
||||
}
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::verifyResults(uint64_t size) {
|
||||
int check_id = (_type == WGGetTestType || _type == WGGetNBITestType ||
|
||||
_type == WAVEGetTestType)
|
||||
int check_id = (_type == WGGetTestType || _type == WGGetNBITestType)
|
||||
? 0
|
||||
: 1;
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#define _EXTENDED_PRIMITIVES_HPP_
|
||||
|
||||
#include "tester.hpp"
|
||||
#include "../src/util.hpp"
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
|
||||
@@ -469,19 +469,31 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
testers.push_back(new ShmemPtrTester(args));
|
||||
return testers;
|
||||
case WGGetTestType:
|
||||
if (rank == 0) std::cout << "Blocking WG level Gets***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.tiled) std::cout << "Tiled Blocking WG level Gets***" << std::endl;
|
||||
else std::cout << "Blocking WG level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGGetNBITestType:
|
||||
if (rank == 0) std::cout << "Non-Blocking WG level Gets***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.tiled) std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl;
|
||||
else std::cout << "Non-Blocking WG level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGPutTestType:
|
||||
if (rank == 0) std::cout << "Blocking WG level Puts***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.tiled) std::cout << "Tiled Blocking WG level Puts***" << std::endl;
|
||||
else std::cout << "Blocking WG level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGPutNBITestType:
|
||||
if (rank == 0) std::cout << "Non-Blocking WG level Puts***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if(args.tiled) std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl;
|
||||
else std::cout << "Non-Blocking WG level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case PutNBIMRTestType:
|
||||
|
||||
@@ -60,6 +60,9 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
} else if (arg == "-x") {
|
||||
i++;
|
||||
shmem_context = atoi(argv[i]);
|
||||
} else if (arg == "-ti") {
|
||||
i++;
|
||||
tiled = (atoi(argv[i]) == 1);
|
||||
} else {
|
||||
show_usage(argv[0]);
|
||||
exit(-1);
|
||||
@@ -119,6 +122,7 @@ void TesterArguments::show_usage(std::string executable_name) {
|
||||
std::cout << "\t-o <Operation type for the random_access test>\n";
|
||||
std::cout << "\t-ta <Number of Thread Accessing the communication>\n";
|
||||
std::cout << "\t-x <shmem context>\n";
|
||||
std::cout << "\t-ti <Tiled version>\n";
|
||||
}
|
||||
|
||||
void TesterArguments::get_rocshmem_arguments() {
|
||||
|
||||
@@ -58,6 +58,7 @@ class TesterArguments {
|
||||
unsigned coal_coef = 64;
|
||||
unsigned op_type = 0;
|
||||
unsigned shmem_context = rocshmem::ROC_SHMEM_CTX_WG_PRIVATE;
|
||||
bool tiled = false;
|
||||
|
||||
/**
|
||||
* Arguments obtained from rocshmem
|
||||
|
||||
Reference in New Issue
Block a user