Add tiled version of puts and gets at wavefront level to the functional test suite
* Implemented tiled version of put*_wave and get*_wave functions
* Maintain single kernel that supports both tiled and untiled versions
* Disable IPC in the default RO build script
[ROCm/rocshmem commit: b6d31ac7ef]
Этот коммит содержится в:
@@ -17,8 +17,9 @@ cmake \
|
||||
-DDEBUG=OFF \
|
||||
-DPROFILE=OFF \
|
||||
-DUSE_GPU_IB=OFF \
|
||||
-DUSE_RO=ON \
|
||||
-DUSE_DC=OFF \
|
||||
-DUSE_IPC=ON \
|
||||
-DUSE_IPC=OFF \
|
||||
-DUSE_THREADS=ON \
|
||||
-DUSE_WF_COAL=OFF \
|
||||
-DUSE_COHERENT_HEAP=ON \
|
||||
|
||||
@@ -68,16 +68,16 @@ case $2 in
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_n2_w1_z1_1MB.log
|
||||
check wg_putnbi_n2_w1_z1_1MB
|
||||
echo "wg_get_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 28 -ti 1 > $3/wg_get_tiled_n2_w1_z1_1MB.log
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 28 > $3/wg_get_tiled_n2_w1_z1_1MB.log
|
||||
check wg_get_tiled_n2_w1_z1_1MB
|
||||
echo "wg_getnbi_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 29 -ti 1 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 29 > $3/wg_getnbi_tiled_n2_w1_z1_1MB.log
|
||||
check wg_getnbi_tiled_n2_w1_z1_1MB
|
||||
echo "wg_put_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 30 -ti 1 > $3/wg_put_tiled_n2_w1_z1_1MB.log
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 30 > $3/wg_put_tiled_n2_w1_z1_1MB.log
|
||||
check wg_put_tiled_n2_w1_z1_1MB
|
||||
echo "wg_putnbi_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 31 -ti 1 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 64 -s 1048576 -a 31 > $3/wg_putnbi_tiled_n2_w1_z1_1MB.log
|
||||
check wg_putnbi_tiled_n2_w1_z1_1MB
|
||||
echo "wave_get_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 32 > $3/wave_get_n2_w1_z1_1MB.log
|
||||
@@ -91,6 +91,18 @@ case $2 in
|
||||
echo "wave_putnbi_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 64 -s 1048576 -a 35 > $3/wave_putnbi_n2_w1_z1_1MB.log
|
||||
check wave_putnbi_n2_w1_z1_1MB
|
||||
echo "wave_get_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 32 > $3/wave_get_tiled_n2_w1_z1_1MB.log
|
||||
check wave_get_tiled_n2_w1_z1_1MB
|
||||
echo "wave_getnbi_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 33 > $3/wave_getnbi_tiled_n2_w1_z1_1MB.log
|
||||
check wave_getnbi_tiled_n2_w1_z1_1MB
|
||||
echo "wave_put_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 34 > $3/wave_put_tiled_n2_w1_z1_1MB.log
|
||||
check wave_put_tiled_n2_w1_z1_1MB
|
||||
echo "wave_putnbi_tiled_n2_w1_z1_1MB"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=2 mpirun -np 2 $1 -w 2 -z 128 -s 1048576 -a 35 > $3/wave_putnbi_tiled_n2_w1_z1_1MB.log
|
||||
check wave_putnbi_tiled_n2_w1_z1_1MB
|
||||
echo "amofadd_n2_w1_z1"
|
||||
ROC_SHMEM_MAX_NUM_CONTEXTS=1 mpirun -np 2 $1 -w 1 -z 1 -a 6 > $3/amofadd_n2_w1_z1.log
|
||||
check amofadd_n2_w1_z1
|
||||
|
||||
@@ -37,47 +37,6 @@ __global__ void ExtendedPrimitiveTest(int loop, int skip, uint64_t *timer,
|
||||
roc_shmem_wg_init();
|
||||
roc_shmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
uint64_t start;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) start = roc_shmem_timer();
|
||||
|
||||
switch (type) {
|
||||
case WGGetTestType:
|
||||
roc_shmemx_ctx_getmem_wg(ctx, r_buf, s_buf, size, 1);
|
||||
break;
|
||||
case WGGetNBITestType:
|
||||
roc_shmemx_ctx_getmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
|
||||
break;
|
||||
case WGPutTestType:
|
||||
roc_shmemx_ctx_putmem_wg(ctx, r_buf, s_buf, size, 1);
|
||||
break;
|
||||
case WGPutNBITestType:
|
||||
roc_shmemx_ctx_putmem_nbi_wg(ctx, r_buf, s_buf, size, 1);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
roc_shmem_ctx_quiet(ctx);
|
||||
|
||||
if (hipThreadIdx_x == 0) {
|
||||
timer[hipBlockIdx_x] = roc_shmem_timer() - start;
|
||||
}
|
||||
|
||||
roc_shmem_wg_ctx_destroy(&ctx);
|
||||
roc_shmem_wg_finalize();
|
||||
}
|
||||
|
||||
__global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer,
|
||||
char *s_buf, char *r_buf, int size,
|
||||
TestType type,
|
||||
ShmemContextType ctx_type) {
|
||||
__shared__ roc_shmem_ctx_t ctx;
|
||||
roc_shmem_wg_init();
|
||||
roc_shmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
uint64_t start;
|
||||
uint64_t idx = size * get_flat_grid_id();
|
||||
s_buf += idx;
|
||||
@@ -119,8 +78,8 @@ __global__ void ExtendedPrimitiveTestTiled(int loop, int skip, uint64_t *timer,
|
||||
*****************************************************************************/
|
||||
ExtendedPrimitiveTester::ExtendedPrimitiveTester(TesterArguments args)
|
||||
: Tester(args) {
|
||||
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
|
||||
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
|
||||
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs);
|
||||
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs);
|
||||
}
|
||||
|
||||
ExtendedPrimitiveTester::~ExtendedPrimitiveTester() {
|
||||
@@ -129,24 +88,17 @@ ExtendedPrimitiveTester::~ExtendedPrimitiveTester() {
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::resetBuffers(uint64_t size) {
|
||||
memset(s_buf, '0', args.max_msg_size * args.wg_size);
|
||||
memset(r_buf, '1', args.max_msg_size * args.wg_size);
|
||||
memset(s_buf, '0', size * args.num_wgs);
|
||||
memset(r_buf, '1', size * args.num_wgs);
|
||||
}
|
||||
|
||||
void ExtendedPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
int loop, uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
if (args.tiled){
|
||||
hipLaunchKernelGGL(ExtendedPrimitiveTestTiled, gridSize, blockSize, shared_bytes,
|
||||
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
|
||||
_shmem_context);
|
||||
}
|
||||
else {
|
||||
hipLaunchKernelGGL(ExtendedPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
|
||||
_shmem_context);
|
||||
}
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
@@ -158,7 +110,7 @@ void ExtendedPrimitiveTester::verifyResults(uint64_t size) {
|
||||
: 1;
|
||||
|
||||
if (args.myid == check_id) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int i = 0; i < size * args.num_wgs; i++) {
|
||||
if (r_buf[i] != '0') {
|
||||
fprintf(stderr, "Data validation error at idx %d\n", i);
|
||||
fprintf(stderr, "Got %c, Expected %c \n", r_buf[i], '0');
|
||||
|
||||
@@ -57,6 +57,9 @@
|
||||
Tester::Tester(TesterArguments args) : args(args) {
|
||||
_type = (TestType)args.algorithm;
|
||||
_shmem_context = args.shmem_context;
|
||||
CHECK_HIP(hipGetDevice(&device_id));
|
||||
CHECK_HIP(hipGetDeviceProperties(&deviceProps, device_id));
|
||||
num_warps = args.wg_size / deviceProps.warpSize;
|
||||
CHECK_HIP(hipStreamCreate(&stream));
|
||||
CHECK_HIP(hipEventCreate(&start_event));
|
||||
CHECK_HIP(hipEventCreate(&stop_event));
|
||||
@@ -470,28 +473,32 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
return testers;
|
||||
case WGGetTestType:
|
||||
if (rank == 0) {
|
||||
if (args.tiled) std::cout << "Tiled Blocking WG level Gets***" << std::endl;
|
||||
if (args.num_wgs > 1)
|
||||
std::cout << "Tiled Blocking WG level Gets***" << std::endl;
|
||||
else std::cout << "Blocking WG level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGGetNBITestType:
|
||||
if (rank == 0) {
|
||||
if (args.tiled) std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl;
|
||||
if (args.num_wgs > 1)
|
||||
std::cout << "Tiled Non-Blocking WG level Gets***" << std::endl;
|
||||
else std::cout << "Non-Blocking WG level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGPutTestType:
|
||||
if (rank == 0) {
|
||||
if (args.tiled) std::cout << "Tiled Blocking WG level Puts***" << std::endl;
|
||||
if (args.num_wgs > 1)
|
||||
std::cout << "Tiled Blocking WG level Puts***" << std::endl;
|
||||
else std::cout << "Blocking WG level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
return testers;
|
||||
case WGPutNBITestType:
|
||||
if (rank == 0) {
|
||||
if(args.tiled) std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl;
|
||||
if (args.num_wgs > 1)
|
||||
std::cout << "Tiled Non-Blocking WG level Puts***" << std::endl;
|
||||
else std::cout << "Non-Blocking WG level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new ExtendedPrimitiveTester(args));
|
||||
@@ -502,19 +509,35 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
testers.push_back(new PrimitiveMRTester(args));
|
||||
return testers;
|
||||
case WAVEGetTestType:
|
||||
if (rank == 0) std::cout << "WAVE Blocking Gets***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1 || args.wg_size / 64 > 1)
|
||||
std::cout << "Tiled Blocking WAVE level Gets***" << std::endl;
|
||||
else std::cout << "Blocking WAVE level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEGetNBITestType:
|
||||
if (rank == 0) std::cout << "WAVE Non-Blocking Gets***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1 || args.wg_size / 64 > 1)
|
||||
std::cout << "Tiled Non-Blocking WAVE level Gets***" << std::endl;
|
||||
else std::cout << "Non-Blocking WAVE level Gets***" << std::endl;
|
||||
}
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEPutTestType:
|
||||
if (rank == 0) std::cout << "WAVE Blocking Puts***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1 || args.wg_size / 64 > 1)
|
||||
std::cout << "Tiled Blocking WAVE level Puts***" << std::endl;
|
||||
else std::cout << "Blocking WAVE level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
return testers;
|
||||
case WAVEPutNBITestType:
|
||||
if (rank == 0) std::cout << "WAVE Non-Blocking Puts***" << std::endl;
|
||||
if (rank == 0) {
|
||||
if (args.num_wgs > 1 || args.wg_size / 64 > 1)
|
||||
std::cout << "Tiled Non-Blocking WAVE level Puts***" << std::endl;
|
||||
else std::cout << "Non-Blocking WAVE level Puts***" << std::endl;
|
||||
}
|
||||
testers.push_back(new WaveLevelPrimitiveTester(args));
|
||||
return testers;
|
||||
default:
|
||||
@@ -684,6 +707,8 @@ uint64_t Tester::gpuCyclesToMicroseconds(uint64_t cycles) {
|
||||
uint64_t Tester::timerAvgInMicroseconds() {
|
||||
uint64_t sum = 0;
|
||||
|
||||
//TODO: Modify the calculation for the Tiled version of puts and gets at
|
||||
// wavefront level (bpotter/avinash)
|
||||
for (int i = 0; i < args.num_wgs; i++) {
|
||||
sum += gpuCyclesToMicroseconds(timer[i]);
|
||||
}
|
||||
|
||||
@@ -117,7 +117,9 @@ class Tester {
|
||||
|
||||
int num_msgs = 0;
|
||||
int num_timed_msgs = 0;
|
||||
int num_warps = 0;
|
||||
int bw_factor = 1;
|
||||
int device_id = 0;
|
||||
|
||||
TesterArguments args;
|
||||
|
||||
@@ -125,6 +127,7 @@ class Tester {
|
||||
ShmemContextType _shmem_context = 8; // SHMEM_CTX_WP_PRIVATE
|
||||
|
||||
hipStream_t stream;
|
||||
hipDeviceProp_t deviceProps;
|
||||
|
||||
uint64_t *timer = nullptr;
|
||||
|
||||
|
||||
@@ -60,9 +60,6 @@ TesterArguments::TesterArguments(int argc, char *argv[]) {
|
||||
} else if (arg == "-x") {
|
||||
i++;
|
||||
shmem_context = atoi(argv[i]);
|
||||
} else if (arg == "-ti") {
|
||||
i++;
|
||||
tiled = (atoi(argv[i]) == 1);
|
||||
} else {
|
||||
show_usage(argv[0]);
|
||||
exit(-1);
|
||||
@@ -122,7 +119,6 @@ void TesterArguments::show_usage(std::string executable_name) {
|
||||
std::cout << "\t-o <Operation type for the random_access test>\n";
|
||||
std::cout << "\t-ta <Number of Thread Accessing the communication>\n";
|
||||
std::cout << "\t-x <shmem context>\n";
|
||||
std::cout << "\t-ti <Tiled version>\n";
|
||||
}
|
||||
|
||||
void TesterArguments::get_rocshmem_arguments() {
|
||||
|
||||
@@ -58,7 +58,6 @@ class TesterArguments {
|
||||
unsigned coal_coef = 64;
|
||||
unsigned op_type = 0;
|
||||
unsigned shmem_context = rocshmem::ROC_SHMEM_CTX_WG_PRIVATE;
|
||||
bool tiled = false;
|
||||
|
||||
/**
|
||||
* Arguments obtained from rocshmem
|
||||
|
||||
@@ -27,17 +27,22 @@
|
||||
using namespace rocshmem;
|
||||
|
||||
/******************************************************************************
|
||||
* DEVICE TEST KERNEL
|
||||
* DEVICE TEST KERNELS
|
||||
*****************************************************************************/
|
||||
__global__ void WaveLevelPrimitiveTest(int loop, int skip, uint64_t *timer,
|
||||
char *s_buf, char *r_buf, int size,
|
||||
TestType type,
|
||||
ShmemContextType ctx_type) {
|
||||
TestType type, ShmemContextType ctx_type,
|
||||
int wf_size) {
|
||||
__shared__ roc_shmem_ctx_t ctx;
|
||||
roc_shmem_wg_init();
|
||||
roc_shmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
uint64_t start;
|
||||
int wf_id = get_flat_block_id() / wf_size;
|
||||
int offset = size * get_flat_grid_id() * (get_flat_block_size() / wf_size);
|
||||
int idx = wf_id * size + offset;
|
||||
s_buf += idx;
|
||||
r_buf += idx;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) start = roc_shmem_timer();
|
||||
@@ -75,8 +80,10 @@ __global__ void WaveLevelPrimitiveTest(int loop, int skip, uint64_t *timer,
|
||||
*****************************************************************************/
|
||||
WaveLevelPrimitiveTester::WaveLevelPrimitiveTester(TesterArguments args)
|
||||
: Tester(args) {
|
||||
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
|
||||
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.wg_size);
|
||||
s_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs
|
||||
* num_warps);
|
||||
r_buf = (char *)roc_shmem_malloc(args.max_msg_size * args.num_wgs
|
||||
* num_warps);
|
||||
}
|
||||
|
||||
WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() {
|
||||
@@ -85,8 +92,8 @@ WaveLevelPrimitiveTester::~WaveLevelPrimitiveTester() {
|
||||
}
|
||||
|
||||
void WaveLevelPrimitiveTester::resetBuffers(uint64_t size) {
|
||||
memset(s_buf, '0', args.max_msg_size * args.wg_size);
|
||||
memset(r_buf, '1', args.max_msg_size * args.wg_size);
|
||||
memset(s_buf, '0', size * args.num_wgs * num_warps);
|
||||
memset(r_buf, '1', size * args.num_wgs * num_warps);
|
||||
}
|
||||
|
||||
void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
@@ -95,10 +102,10 @@ void WaveLevelPrimitiveTester::launchKernel(dim3 gridSize, dim3 blockSize,
|
||||
|
||||
hipLaunchKernelGGL(WaveLevelPrimitiveTest, gridSize, blockSize, shared_bytes,
|
||||
stream, loop, args.skip, timer, s_buf, r_buf, size, _type,
|
||||
_shmem_context);
|
||||
_shmem_context, deviceProps.warpSize);
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
num_timed_msgs = loop;
|
||||
num_timed_msgs = loop * gridSize.x;
|
||||
}
|
||||
|
||||
void WaveLevelPrimitiveTester::verifyResults(uint64_t size) {
|
||||
@@ -107,7 +114,7 @@ void WaveLevelPrimitiveTester::verifyResults(uint64_t size) {
|
||||
: 1;
|
||||
|
||||
if (args.myid == check_id) {
|
||||
for (int i = 0; i < size; i++) {
|
||||
for (int i = 0; i < size * args.num_wgs * num_warps; i++) {
|
||||
if (r_buf[i] != '0') {
|
||||
fprintf(stderr, "Data validation error at idx %d\n", i);
|
||||
fprintf(stderr, "Got %c, Expected %c \n", r_buf[i], '0');
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#define _WAVE_LEVEL_PRIMITIVE_TEST_HPP_
|
||||
|
||||
#include "tester.hpp"
|
||||
#include "../src/util.hpp"
|
||||
|
||||
/******************************************************************************
|
||||
* HOST TESTER CLASS
|
||||
|
||||
Ссылка в новой задаче
Block a user