Validate signal after put signal operations
[ROCm/rocshmem commit: 8d74c7b73e]
Этот коммит содержится в:
@@ -270,15 +270,18 @@ TestSigOps() {
|
||||
ExecTest "putsignal" 2 2 32 1048576
|
||||
ExecTest "wgputsignal" 2 2 32 1048576
|
||||
ExecTest "waveputsignal" 2 1 32 1048576
|
||||
ExecTest "waveputsignal" 2 1 64 1048576
|
||||
|
||||
ExecTest "putsignalnbi" 2 1 1 1048576
|
||||
ExecTest "putsignalnbi" 2 2 32 1048576
|
||||
ExecTest "wgputsignalnbi" 2 2 32 1048576
|
||||
ExecTest "waveputsignalnbi" 2 1 32 1048576
|
||||
ExecTest "waveputsignalnbi" 2 1 64 1048576
|
||||
|
||||
ExecTest "signalfetch" 2 1 1
|
||||
ExecTest "wgsignalfetch" 2 2 32
|
||||
ExecTest "wavesignalfetch" 2 1 32
|
||||
ExecTest "wavesignalfetch" 2 1 64
|
||||
}
|
||||
|
||||
TestColl() {
|
||||
|
||||
@@ -31,14 +31,13 @@ using namespace rocshmem;
|
||||
*****************************************************************************/
|
||||
__global__ void PutmemSignalTest(int loop, int skip, uint64_t *timer, char *s_buf,
|
||||
char *r_buf, int size, uint64_t *sig_addr,
|
||||
TestType type, ShmemContextType ctx_type) {
|
||||
TestType type, ShmemContextType ctx_type, int sig_op) {
|
||||
__shared__ rocshmem_ctx_t ctx;
|
||||
rocshmem_wg_init();
|
||||
rocshmem_wg_ctx_create(ctx_type, &ctx);
|
||||
|
||||
uint64_t start;
|
||||
uint64_t signal = 0;
|
||||
int sig_op = ROCSHMEM_SIGNAL_SET;
|
||||
uint64_t signal = 1;
|
||||
|
||||
for (int i = 0; i < loop + skip; i++) {
|
||||
if (i == skip) {
|
||||
@@ -128,6 +127,11 @@ SignalingOperationsTester::SignalingOperationsTester(TesterArguments args) : Tes
|
||||
CHECK_HIP(hipMallocManaged(&fetched_value, sizeof(uint64_t), hipMemAttachHost));
|
||||
}
|
||||
|
||||
SignalingOperationsTester::SignalingOperationsTester(TesterArguments args, int signal_op)
|
||||
: SignalingOperationsTester(args) {
|
||||
sig_op = signal_op;
|
||||
}
|
||||
|
||||
SignalingOperationsTester::~SignalingOperationsTester() {
|
||||
rocshmem_free(s_buf);
|
||||
rocshmem_free(r_buf);
|
||||
@@ -146,6 +150,7 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int
|
||||
uint64_t size) {
|
||||
size_t shared_bytes = 0;
|
||||
|
||||
|
||||
if ((_type == SignalFetchTestType) ||
|
||||
(_type == WAVESignalFetchTestType) ||
|
||||
(_type == WGSignalFetchTestType)) {
|
||||
@@ -154,7 +159,7 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int
|
||||
} else {
|
||||
hipLaunchKernelGGL(PutmemSignalTest, gridSize, blockSize, shared_bytes, stream,
|
||||
loop, args.skip, timer, s_buf, r_buf, size, sig_addr,
|
||||
_type, _shmem_context);
|
||||
_type, _shmem_context, sig_op);
|
||||
}
|
||||
|
||||
num_msgs = (loop + args.skip) * gridSize.x;
|
||||
@@ -162,35 +167,66 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int
|
||||
}
|
||||
|
||||
void SignalingOperationsTester::verifyResults(uint64_t size) {
|
||||
int check_data_id = (_type == PutSignalTestType ||
|
||||
_type == PutSignalNBITestType ||
|
||||
_type == WAVEPutSignalTestType ||
|
||||
_type == WAVEPutSignalNBITestType ||
|
||||
_type == WGPutSignalTestType ||
|
||||
_type == WGPutSignalNBITestType)
|
||||
? 1 : -1; // do not check if it doesn't match a test
|
||||
|
||||
int check_fetched_value_id = (_type == SignalFetchTestType ||
|
||||
_type == WAVESignalFetchTestType ||
|
||||
_type == WGSignalFetchTestType)
|
||||
? 0 : -1; // do not check if it doesn't match a test
|
||||
|
||||
if (args.myid == check_data_id) {
|
||||
for (uint64_t i = 0; i < size; i++) {
|
||||
if (r_buf[i] != '0') {
|
||||
fprintf(stderr, "Data validation error at idx %lu\n", i);
|
||||
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
|
||||
if (_type == SignalFetchTestType ||
|
||||
_type == WAVESignalFetchTestType ||
|
||||
_type == WGSignalFetchTestType) {
|
||||
if (0 == args.myid) {
|
||||
uint64_t value = *fetched_value;
|
||||
uint64_t expected_value = (args.myid + 123);
|
||||
if (value != expected_value) {
|
||||
fprintf(stderr, "Fetched Value %lu, Expected %lu\n", value, expected_value);
|
||||
exit(-1);
|
||||
}
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
if (1 == args.myid) {
|
||||
// Validate Data
|
||||
for (uint64_t i = 0; i < size; i++) {
|
||||
if (r_buf[i] != '0') {
|
||||
fprintf(stderr, "Data validation error at idx %lu\n", i);
|
||||
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
// Validate Signal
|
||||
if (ROCSHMEM_SIGNAL_SET == sig_op) {
|
||||
uint64_t expected_value = 1;
|
||||
uint64_t value = *sig_addr;
|
||||
|
||||
if (value != expected_value) {
|
||||
fprintf(stderr, "ROCSHMEM_SIGNAL_SET Value %lu, Expected %lu\n", value, expected_value);
|
||||
exit(-1);
|
||||
}
|
||||
} else if (ROCSHMEM_SIGNAL_ADD == sig_op) {
|
||||
uint64_t value = *sig_addr;
|
||||
uint64_t expected_value = (args.myid + 123); // Initial Value
|
||||
uint64_t num_waves = 1;
|
||||
|
||||
switch (_type) {
|
||||
case PutSignalTestType:
|
||||
case PutSignalNBITestType:
|
||||
expected_value += ((args.skip + args.loop) * args.wg_size * args.num_wgs);
|
||||
break;
|
||||
case WGPutSignalTestType:
|
||||
case WGPutSignalNBITestType:
|
||||
expected_value += ((args.skip + args.loop) * args.num_wgs);
|
||||
break;
|
||||
case WAVEPutSignalTestType:
|
||||
case WAVEPutSignalNBITestType:
|
||||
num_waves = max(1, (args.num_wgs / __AMDGCN_WAVEFRONT_SIZE__));
|
||||
expected_value += ((args.skip + args.loop) * num_waves);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "Invalid Test\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (value != expected_value) {
|
||||
fprintf(stderr, "ROCSHMEM_SIGNAL_ADD Value %lu, Expected %lu\n", value, expected_value);
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (args.myid == check_fetched_value_id) {
|
||||
uint64_t value = *fetched_value;
|
||||
uint64_t expected_value = (args.myid + 123);
|
||||
if (value != expected_value) {
|
||||
fprintf(stderr, "Fetched Value %lu, Expected %lu\n", value, expected_value);
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@
|
||||
class SignalingOperationsTester : public Tester {
|
||||
public:
|
||||
explicit SignalingOperationsTester(TesterArguments args);
|
||||
explicit SignalingOperationsTester(TesterArguments args, int signal_op);
|
||||
virtual ~SignalingOperationsTester();
|
||||
|
||||
protected:
|
||||
@@ -41,6 +42,7 @@ class SignalingOperationsTester : public Tester {
|
||||
|
||||
virtual void verifyResults(uint64_t size) override;
|
||||
|
||||
int sig_op;
|
||||
char *s_buf = nullptr;
|
||||
char *r_buf = nullptr;
|
||||
uint64_t *sig_addr;
|
||||
|
||||
@@ -422,27 +422,33 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
|
||||
return testers;
|
||||
case PutSignalTestType:
|
||||
if (rank == 0) std::cout << "Putmem Signal ###" << std::endl;
|
||||
testers.push_back(new SignalingOperationsTester(args));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
|
||||
return testers;
|
||||
case WGPutSignalTestType:
|
||||
if (rank == 0) std::cout << "WG Putmem Signal ###" << std::endl;
|
||||
testers.push_back(new SignalingOperationsTester(args));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
|
||||
return testers;
|
||||
case WAVEPutSignalTestType:
|
||||
if (rank == 0) std::cout << "Wave Putmem Signal ###" << std::endl;
|
||||
testers.push_back(new SignalingOperationsTester(args));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
|
||||
return testers;
|
||||
case PutSignalNBITestType:
|
||||
if (rank == 0) std::cout << "Non-Blocking Putmem Signal ###" << std::endl;
|
||||
testers.push_back(new SignalingOperationsTester(args));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
|
||||
return testers;
|
||||
case WGPutSignalNBITestType:
|
||||
if (rank == 0) std::cout << "Non-Blocking WG Putmem Signal ###" << std::endl;
|
||||
testers.push_back(new SignalingOperationsTester(args));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
|
||||
return testers;
|
||||
case WAVEPutSignalNBITestType:
|
||||
if (rank == 0) std::cout << "Non-Blocking Wave Putmem Signal ###" << std::endl;
|
||||
testers.push_back(new SignalingOperationsTester(args));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
|
||||
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
|
||||
return testers;
|
||||
case SignalFetchTestType:
|
||||
if (rank == 0) std::cout << "Signal Fetch ###" << std::endl;
|
||||
|
||||
Ссылка в новой задаче
Block a user