Validate signal after put signal operations

[ROCm/rocshmem commit: 8d74c7b73e]
Этот коммит содержится в:
Yiltan Hassan Temucin
2025-02-06 08:16:04 -06:00
родитель 9317172fab
Коммит 257610bdc5
4 изменённых файлов: 84 добавлений и 37 удалений
+3
Просмотреть файл
@@ -270,15 +270,18 @@ TestSigOps() {
ExecTest "putsignal" 2 2 32 1048576
ExecTest "wgputsignal" 2 2 32 1048576
ExecTest "waveputsignal" 2 1 32 1048576
ExecTest "waveputsignal" 2 1 64 1048576
ExecTest "putsignalnbi" 2 1 1 1048576
ExecTest "putsignalnbi" 2 2 32 1048576
ExecTest "wgputsignalnbi" 2 2 32 1048576
ExecTest "waveputsignalnbi" 2 1 32 1048576
ExecTest "waveputsignalnbi" 2 1 64 1048576
ExecTest "signalfetch" 2 1 1
ExecTest "wgsignalfetch" 2 2 32
ExecTest "wavesignalfetch" 2 1 32
ExecTest "wavesignalfetch" 2 1 64
}
TestColl() {
+67 -31
Просмотреть файл
@@ -31,14 +31,13 @@ using namespace rocshmem;
*****************************************************************************/
__global__ void PutmemSignalTest(int loop, int skip, uint64_t *timer, char *s_buf,
char *r_buf, int size, uint64_t *sig_addr,
TestType type, ShmemContextType ctx_type) {
TestType type, ShmemContextType ctx_type, int sig_op) {
__shared__ rocshmem_ctx_t ctx;
rocshmem_wg_init();
rocshmem_wg_ctx_create(ctx_type, &ctx);
uint64_t start;
uint64_t signal = 0;
int sig_op = ROCSHMEM_SIGNAL_SET;
uint64_t signal = 1;
for (int i = 0; i < loop + skip; i++) {
if (i == skip) {
@@ -128,6 +127,11 @@ SignalingOperationsTester::SignalingOperationsTester(TesterArguments args) : Tes
CHECK_HIP(hipMallocManaged(&fetched_value, sizeof(uint64_t), hipMemAttachHost));
}
SignalingOperationsTester::SignalingOperationsTester(TesterArguments args, int signal_op)
: SignalingOperationsTester(args) {
sig_op = signal_op;
}
SignalingOperationsTester::~SignalingOperationsTester() {
rocshmem_free(s_buf);
rocshmem_free(r_buf);
@@ -146,6 +150,7 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int
uint64_t size) {
size_t shared_bytes = 0;
if ((_type == SignalFetchTestType) ||
(_type == WAVESignalFetchTestType) ||
(_type == WGSignalFetchTestType)) {
@@ -154,7 +159,7 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int
} else {
hipLaunchKernelGGL(PutmemSignalTest, gridSize, blockSize, shared_bytes, stream,
loop, args.skip, timer, s_buf, r_buf, size, sig_addr,
_type, _shmem_context);
_type, _shmem_context, sig_op);
}
num_msgs = (loop + args.skip) * gridSize.x;
@@ -162,35 +167,66 @@ void SignalingOperationsTester::launchKernel(dim3 gridSize, dim3 blockSize, int
}
void SignalingOperationsTester::verifyResults(uint64_t size) {
int check_data_id = (_type == PutSignalTestType ||
_type == PutSignalNBITestType ||
_type == WAVEPutSignalTestType ||
_type == WAVEPutSignalNBITestType ||
_type == WGPutSignalTestType ||
_type == WGPutSignalNBITestType)
? 1 : -1; // do not check if it doesn't match a test
int check_fetched_value_id = (_type == SignalFetchTestType ||
_type == WAVESignalFetchTestType ||
_type == WGSignalFetchTestType)
? 0 : -1; // do not check if it doesn't match a test
if (args.myid == check_data_id) {
for (uint64_t i = 0; i < size; i++) {
if (r_buf[i] != '0') {
fprintf(stderr, "Data validation error at idx %lu\n", i);
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
if (_type == SignalFetchTestType ||
_type == WAVESignalFetchTestType ||
_type == WGSignalFetchTestType) {
if (0 == args.myid) {
uint64_t value = *fetched_value;
uint64_t expected_value = (args.myid + 123);
if (value != expected_value) {
fprintf(stderr, "Fetched Value %lu, Expected %lu\n", value, expected_value);
exit(-1);
}
return;
}
} else {
if (1 == args.myid) {
// Validate Data
for (uint64_t i = 0; i < size; i++) {
if (r_buf[i] != '0') {
fprintf(stderr, "Data validation error at idx %lu\n", i);
fprintf(stderr, "Got %c, Expected %c\n", r_buf[i], '0');
exit(-1);
}
}
// Validate Signal
if (ROCSHMEM_SIGNAL_SET == sig_op) {
uint64_t expected_value = 1;
uint64_t value = *sig_addr;
if (value != expected_value) {
fprintf(stderr, "ROCSHMEM_SIGNAL_SET Value %lu, Expected %lu\n", value, expected_value);
exit(-1);
}
} else if (ROCSHMEM_SIGNAL_ADD == sig_op) {
uint64_t value = *sig_addr;
uint64_t expected_value = (args.myid + 123); // Initial Value
uint64_t num_waves = 1;
switch (_type) {
case PutSignalTestType:
case PutSignalNBITestType:
expected_value += ((args.skip + args.loop) * args.wg_size * args.num_wgs);
break;
case WGPutSignalTestType:
case WGPutSignalNBITestType:
expected_value += ((args.skip + args.loop) * args.num_wgs);
break;
case WAVEPutSignalTestType:
case WAVEPutSignalNBITestType:
num_waves = max(1, (args.num_wgs / __AMDGCN_WAVEFRONT_SIZE__));
expected_value += ((args.skip + args.loop) * num_waves);
break;
default:
fprintf(stderr, "Invalid Test\n");
exit(-1);
}
if (value != expected_value) {
fprintf(stderr, "ROCSHMEM_SIGNAL_ADD Value %lu, Expected %lu\n", value, expected_value);
exit(-1);
}
}
}
}
if (args.myid == check_fetched_value_id) {
uint64_t value = *fetched_value;
uint64_t expected_value = (args.myid + 123);
if (value != expected_value) {
fprintf(stderr, "Fetched Value %lu, Expected %lu\n", value, expected_value);
exit(-1);
}
}
}
+2
Просмотреть файл
@@ -31,6 +31,7 @@
class SignalingOperationsTester : public Tester {
public:
explicit SignalingOperationsTester(TesterArguments args);
explicit SignalingOperationsTester(TesterArguments args, int signal_op);
virtual ~SignalingOperationsTester();
protected:
@@ -41,6 +42,7 @@ class SignalingOperationsTester : public Tester {
virtual void verifyResults(uint64_t size) override;
int sig_op;
char *s_buf = nullptr;
char *r_buf = nullptr;
uint64_t *sig_addr;
+12 -6
Просмотреть файл
@@ -422,27 +422,33 @@ std::vector<Tester*> Tester::create(TesterArguments args) {
return testers;
case PutSignalTestType:
if (rank == 0) std::cout << "Putmem Signal ###" << std::endl;
testers.push_back(new SignalingOperationsTester(args));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
return testers;
case WGPutSignalTestType:
if (rank == 0) std::cout << "WG Putmem Signal ###" << std::endl;
testers.push_back(new SignalingOperationsTester(args));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
return testers;
case WAVEPutSignalTestType:
if (rank == 0) std::cout << "Wave Putmem Signal ###" << std::endl;
testers.push_back(new SignalingOperationsTester(args));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
return testers;
case PutSignalNBITestType:
if (rank == 0) std::cout << "Non-Blocking Putmem Signal ###" << std::endl;
testers.push_back(new SignalingOperationsTester(args));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
return testers;
case WGPutSignalNBITestType:
if (rank == 0) std::cout << "Non-Blocking WG Putmem Signal ###" << std::endl;
testers.push_back(new SignalingOperationsTester(args));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
return testers;
case WAVEPutSignalNBITestType:
if (rank == 0) std::cout << "Non-Blocking Wave Putmem Signal ###" << std::endl;
testers.push_back(new SignalingOperationsTester(args));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_SET));
testers.push_back(new SignalingOperationsTester(args, ROCSHMEM_SIGNAL_ADD));
return testers;
case SignalFetchTestType:
if (rank == 0) std::cout << "Signal Fetch ###" << std::endl;