From 56eb68bc4a27d201c3859e8f80288a7ac50eceae Mon Sep 17 00:00:00 2001 From: Edgar Gabriel Date: Fri, 1 Aug 2025 08:50:14 -0500 Subject: [PATCH] Add extended team tests (#207) Create teams in the functional test that are not a duplicate of the ROCSHMEM_TEAM_WORLD. This commit contains only infra-tests to make sure that n_pes and my_pe on the new teams are indeed correct. [ROCm/rocshmem commit: e95360961dfa5575a49fe79abdf41ab9e4aeb543] --- .../scripts/functional_tests/driver.sh | 10 +- .../team_ctx_infra_tester.cpp | 162 ++++++++++++++---- .../team_ctx_infra_tester.hpp | 5 + .../tests/functional_tests/tester.cpp | 22 ++- .../tests/functional_tests/tester.hpp | 3 + .../functional_tests/tester_arguments.cpp | 6 +- .../functional_tests/tester_arguments.hpp | 12 +- 7 files changed, 187 insertions(+), 33 deletions(-) diff --git a/projects/rocshmem/scripts/functional_tests/driver.sh b/projects/rocshmem/scripts/functional_tests/driver.sh index f164ab904f..3c4902134f 100755 --- a/projects/rocshmem/scripts/functional_tests/driver.sh +++ b/projects/rocshmem/scripts/functional_tests/driver.sh @@ -106,6 +106,9 @@ declare -A TEST_NUMBERS=( ["teamwavebarrier"]="70" ["wavesync"]="71" ["wgsync"]="72" + ["teamctxsingleinfra"]="73" + ["teamctxblockinfra"]="74" + ["teamctxoddeveninfra"]="75" ) ExecTest() { @@ -422,7 +425,12 @@ TestOther() { # This test requires more contexts than workgroups export ROCSHMEM_MAX_NUM_CONTEXTS=1024 - ExecTest "teamctxinfra" 2 1 1 + ExecTest "teamctxinfra" 2 1 1 + ExecTest "teamctxsingleinfra" 2 1 1 + ExecTest "teamctxblockinfra" 4 1 1 + ExecTest "teamctxblockinfra" 5 1 1 + ExecTest "teamctxoddeveninfra" 4 1 1 + ExecTest "teamctxoddeveninfra" 5 1 1 unset ROCSHMEM_MAX_NUM_CONTEXTS } diff --git a/projects/rocshmem/tests/functional_tests/team_ctx_infra_tester.cpp b/projects/rocshmem/tests/functional_tests/team_ctx_infra_tester.cpp index c1af21a4c4..aeeb479ae9 100644 --- a/projects/rocshmem/tests/functional_tests/team_ctx_infra_tester.cpp +++
b/projects/rocshmem/tests/functional_tests/team_ctx_infra_tester.cpp @@ -39,7 +39,35 @@ rocshmem_team_t team_world_dup[NUM_TEAMS]; /****************************************************************************** * DEVICE TEST KERNEL *****************************************************************************/ -__global__ void TeamCtxInfraTest(ShmemContextType ctx_type, + __global__ void TeamCtxInfraSimpleTest(ShmemContextType ctx_type, + rocshmem_team_t team, + int expected_pe, int expected_n_pes) { + __shared__ rocshmem_ctx_t ctx; + + rocshmem_wg_init(); + rocshmem_wg_team_create_ctx(team, ctx_type, &ctx); + + int num_pes = rocshmem_ctx_n_pes(ctx); + int my_pe = rocshmem_ctx_my_pe(ctx); + + if (my_pe != expected_pe) { + printf("PE doesn't match. Expected %d got %d\n", expected_pe, my_pe); + abort(); + } + + if (num_pes != expected_n_pes) { + printf("Team size doesn't match. Expected %d got %d\n", expected_n_pes, num_pes); + abort(); + } + + __syncthreads(); + + rocshmem_ctx_quiet(ctx); + rocshmem_wg_ctx_destroy(&ctx); + rocshmem_wg_finalize(); + } + + __global__ void TeamCtxInfraTest(ShmemContextType ctx_type, rocshmem_team_t *team) { __shared__ rocshmem_ctx_t ctx1, ctx2, ctx3; __shared__ rocshmem_ctx_t ctx[NUM_TEAMS]; @@ -109,42 +137,105 @@ __global__ void TeamCtxInfraTest(ShmemContextType ctx_type, /****************************************************************************** * HOST TESTER CLASS METHODS *****************************************************************************/ -TeamCtxInfraTester::TeamCtxInfraTester(TesterArguments args) : Tester(args) {} +TeamCtxInfraTester::TeamCtxInfraTester(TesterArguments args) : Tester(args) { + _splitType = args.team_type; +} TeamCtxInfraTester::~TeamCtxInfraTester() {} void TeamCtxInfraTester::resetBuffers(size_t size) {} void TeamCtxInfraTester::preLaunchKernel() { - int n_pes = rocshmem_team_n_pes(ROCSHMEM_TEAM_WORLD); + int n_pes = rocshmem_team_n_pes(_parentTeam); + int my_pe = rocshmem_team_my_pe(_parentTeam); 
- // validate we can run the test - if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) { - int max_ctx = atoi(maximum_num_contexts_str); - if (max_ctx <= NUM_TEAMS) { - printf("ROCSHMEM_MAX_NUM_CONTEXTS=%d is smaller than NUM_TEAMS %d, invalid test setup!\n", max_ctx, NUM_TEAMS); - assert(max_ctx > NUM_TEAMS); + if (_splitType == ROCSHMEM_TEST_TEAM_DUP) { + // validate we can run the test + if (auto maximum_num_contexts_str = getenv("ROCSHMEM_MAX_NUM_CONTEXTS")) { + int max_ctx = atoi(maximum_num_contexts_str); + if (max_ctx <= NUM_TEAMS) { + printf("ROCSHMEM_MAX_NUM_CONTEXTS=%d is smaller than NUM_TEAMS %d, invalid test setup!\n", max_ctx, NUM_TEAMS); + assert(max_ctx > NUM_TEAMS); + abort(); + } + } + + for (int team_i = 0; team_i < NUM_TEAMS; team_i++) { + team_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(_parentTeam, 0, 1, n_pes, nullptr, 0, + &team_world_dup[team_i]); + if (team_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { + printf("Created team %d is invalid!\n", team_i); + abort(); + } + } + + /* Assert the failure of a new team creation. 
*/ + rocshmem_team_t new_team = ROCSHMEM_TEAM_INVALID; + rocshmem_team_split_strided(_parentTeam, 0, 1, n_pes, nullptr, 0, + &new_team); + if (new_team != ROCSHMEM_TEAM_INVALID) { + printf("Created new team should have been invalid!\n"); abort(); } } + else if (_splitType == ROCSHMEM_TEST_TEAM_SINGLE) { + rocshmem_team_split_strided(_parentTeam, my_pe, 1, 1, nullptr, 0, + &team_world_dup[0]); + _expected_pe = rocshmem_team_my_pe(team_world_dup[0]); + _expected_n_pes = rocshmem_team_n_pes(team_world_dup[0]); - for (int team_i = 0; team_i < NUM_TEAMS; team_i++) { - team_world_dup[team_i] = ROCSHMEM_TEAM_INVALID; - rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, - &team_world_dup[team_i]); - if (team_world_dup[team_i] == ROCSHMEM_TEAM_INVALID) { - printf("Created team %d is invalid!\n", team_i); + if (_expected_n_pes != 1) { + printf("ROCSHMEM_TEST_TEAM_SINGLE: n_pes %d expected: 1\n", _expected_n_pes); abort(); } - } - /* Assert the failure of a new team creation. */ - rocshmem_team_t new_team = ROCSHMEM_TEAM_INVALID; - rocshmem_team_split_strided(ROCSHMEM_TEAM_WORLD, 0, 1, n_pes, nullptr, 0, - &new_team); - if (new_team != ROCSHMEM_TEAM_INVALID) { - printf("Created new team should have been invalid!\n"); - abort(); + if (_expected_pe != 0) { + printf("ROCSHMEM_TEST_TEAM_SINGLE: my_pe %d expected: 0\n", _expected_pe); + abort(); + } + } else if (_splitType == ROCSHMEM_TEST_TEAM_BLOCK) { + int mid_pe = n_pes / 2; // integer division + int start_pe = my_pe < mid_pe ? 0 : mid_pe; + int end_pe = my_pe < mid_pe ? (mid_pe - 1) : (n_pes - 1); + int num_pes = end_pe - start_pe + 1; + int new_pe = my_pe < mid_pe ? 
my_pe : (my_pe - start_pe); + + rocshmem_team_split_strided(_parentTeam, start_pe, 1, num_pes, nullptr, 0, + &team_world_dup[0]); + _expected_pe = rocshmem_team_my_pe(team_world_dup[0]); + _expected_n_pes = rocshmem_team_n_pes(team_world_dup[0]); + + if (_expected_n_pes != num_pes) { + printf("ROCSHMEM_TEST_TEAM_BLOCK: n_pes %d expected: %d\n", _expected_n_pes, num_pes); + abort(); + } + + if (_expected_pe != new_pe) { + printf("ROCSHMEM_TEST_TEAM_BLOCK: my_pe %d expected: %d\n", _expected_pe, new_pe); + abort(); + } + } else if (_splitType == ROCSHMEM_TEST_TEAM_ODDEVEN) { + int start_pe = (my_pe % 2) == 0 ? 0 : 1; + int num_pes = n_pes / 2; + if (((n_pes % 2) != 0) && ((my_pe % 2) == 0)) + num_pes++; + int new_pe = (my_pe / 2); + + rocshmem_team_split_strided(_parentTeam, start_pe, 2, num_pes, nullptr, 0, + &team_world_dup[0]); + _expected_pe = rocshmem_team_my_pe(team_world_dup[0]); + _expected_n_pes = rocshmem_team_n_pes(team_world_dup[0]); + + if (_expected_n_pes != num_pes) { + printf("ROCSHMEM_TEST_TEAM_ODDEVEN: n_pes %d expected: %d\n", _expected_n_pes, num_pes); + abort(); + } + + if (_expected_pe != new_pe) { + printf("ROCSHMEM_TEST_TEAM_ODDEVEN: my_pe %d expected: %d\n", _expected_pe, new_pe); + abort(); + } } } @@ -154,18 +245,31 @@ void TeamCtxInfraTester::launchKernel(dim3 gridSize, dim3 blockSize, int loop, /* Copy array of teams to device */ rocshmem_team_t *teams_on_device; - CHECK_HIP(hipMalloc(&teams_on_device, sizeof(rocshmem_team_t) * NUM_TEAMS)); - CHECK_HIP(hipMemcpy(teams_on_device, team_world_dup, - sizeof(rocshmem_team_t) * NUM_TEAMS, hipMemcpyHostToDevice)); - hipLaunchKernelGGL(TeamCtxInfraTest, gridSize, blockSize, shared_bytes, - stream, _shmem_context, teams_on_device); + if (_splitType == ROCSHMEM_TEST_TEAM_DUP) { + CHECK_HIP(hipMalloc(&teams_on_device, sizeof(rocshmem_team_t) * NUM_TEAMS)); + CHECK_HIP(hipMemcpy(teams_on_device, team_world_dup, + sizeof(rocshmem_team_t) * NUM_TEAMS, hipMemcpyHostToDevice)); + + 
hipLaunchKernelGGL(TeamCtxInfraTest, gridSize, blockSize, shared_bytes, + stream, _shmem_context, teams_on_device); + } else if (_splitType == ROCSHMEM_TEST_TEAM_SINGLE || + _splitType == ROCSHMEM_TEST_TEAM_BLOCK || + _splitType == ROCSHMEM_TEST_TEAM_ODDEVEN ) { + CHECK_HIP(hipMalloc(&teams_on_device, sizeof(rocshmem_team_t))); + CHECK_HIP(hipMemcpy(teams_on_device, team_world_dup, + sizeof(rocshmem_team_t), hipMemcpyHostToDevice)); + + hipLaunchKernelGGL(TeamCtxInfraSimpleTest, gridSize, blockSize, shared_bytes, + stream, _shmem_context, teams_on_device[0], _expected_pe, _expected_n_pes); + } CHECK_HIP(hipFree(teams_on_device)); } void TeamCtxInfraTester::postLaunchKernel() { - for (int team_i = 0; team_i < NUM_TEAMS; team_i++) { + int num_teams = _splitType == ROCSHMEM_TEST_TEAM_DUP ? NUM_TEAMS : 1; + for (int team_i = 0; team_i < num_teams; team_i++) { rocshmem_team_destroy(team_world_dup[team_i]); } } diff --git a/projects/rocshmem/tests/functional_tests/team_ctx_infra_tester.hpp b/projects/rocshmem/tests/functional_tests/team_ctx_infra_tester.hpp index 2b108ad980..a0a2417695 100644 --- a/projects/rocshmem/tests/functional_tests/team_ctx_infra_tester.hpp +++ b/projects/rocshmem/tests/functional_tests/team_ctx_infra_tester.hpp @@ -49,6 +49,11 @@ class TeamCtxInfraTester : public Tester { char *s_buf = nullptr; char *r_buf = nullptr; + + TeamSplitType _splitType; + rocshmem::rocshmem_team_t _parentTeam = rocshmem::ROCSHMEM_TEAM_WORLD; + int _expected_pe; + int _expected_n_pes; }; #endif diff --git a/projects/rocshmem/tests/functional_tests/tester.cpp b/projects/rocshmem/tests/functional_tests/tester.cpp index 0d891dcc62..52c9c60bd0 100644 --- a/projects/rocshmem/tests/functional_tests/tester.cpp +++ b/projects/rocshmem/tests/functional_tests/tester.cpp @@ -150,6 +150,21 @@ std::vector Tester::create(TesterArguments args) { if (rank == 0) std::cout << "Team Ctx Infra test ###" << std::endl; testers.push_back(new TeamCtxInfraTester(args)); return testers; + case 
TeamCtxInfraTestSingleType: + if (rank == 0) std::cout << "Team Ctx Infra Single test ###" << std::endl; + args.team_type = ROCSHMEM_TEST_TEAM_SINGLE; + testers.push_back(new TeamCtxInfraTester(args)); + return testers; + case TeamCtxInfraTestBlockType: + if (rank == 0) std::cout << "Team Ctx Infra Block test ###" << std::endl; + args.team_type = ROCSHMEM_TEST_TEAM_BLOCK; + testers.push_back(new TeamCtxInfraTester(args)); + return testers; + case TeamCtxInfraTestOddEvenType: + if (rank == 0) std::cout << "Team Ctx Infra Odd-Even test ###" << std::endl; + args.team_type = ROCSHMEM_TEST_TEAM_ODDEVEN; + testers.push_back(new TeamCtxInfraTester(args)); + return testers; case TeamCtxGetTestType: if (rank == 0) std::cout << "Blocking Team Ctx Gets ###" << std::endl; testers.push_back(new TeamCtxPrimitiveTester(args)); @@ -527,7 +542,10 @@ void Tester::execute() { barrier(); - if (_type != TeamCtxInfraTestType) { + if (_type != TeamCtxInfraTestType && + _type != TeamCtxInfraTestSingleType && + _type != TeamCtxInfraTestBlockType && + _type != TeamCtxInfraTestOddEvenType ) { print(size); } } @@ -546,6 +564,8 @@ bool Tester::peLaunchesKernel() { */ is_launcher = is_launcher || (_type == TeamReductionTestType) || (_type == TeamBroadcastTestType) || (_type == TeamCtxInfraTestType) || + (_type == TeamCtxInfraTestSingleType) || (_type == TeamCtxInfraTestBlockType) || + (_type == TeamCtxInfraTestOddEvenType) || (_type == TeamAllToAllTestType) || (_type == TeamFCollectTestType) || (_type == PingPongTestType) || (_type == BarrierAllTestType) || (_type == WAVEBarrierAllTestType) || (_type == WGBarrierAllTestType) || diff --git a/projects/rocshmem/tests/functional_tests/tester.hpp b/projects/rocshmem/tests/functional_tests/tester.hpp index 877f8d22f7..ddb65c6508 100644 --- a/projects/rocshmem/tests/functional_tests/tester.hpp +++ b/projects/rocshmem/tests/functional_tests/tester.hpp @@ -110,6 +110,9 @@ enum TestType { TeamWAVEBarrierTestType = 70, WAVESyncTestType = 71, 
WGSyncTestType = 72, + TeamCtxInfraTestSingleType = 73, + TeamCtxInfraTestBlockType = 74, + TeamCtxInfraTestOddEvenType = 75, }; enum OpType { PutType = 0, GetType = 1 }; diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp index 933a1056bc..2394947331 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.cpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.cpp @@ -113,6 +113,9 @@ TesterArguments::TesterArguments(int argc, char *argv[]) { min_msg_size = 8; break; case TeamCtxInfraTestType: + case TeamCtxInfraTestSingleType: + case TeamCtxInfraTestBlockType: + case TeamCtxInfraTestOddEvenType: max_msg_size = min_msg_size; break; case PutNBIMRTestType: @@ -149,7 +152,8 @@ void TesterArguments::get_rocshmem_arguments() { (type != TeamFCollectTestType) && (type != TeamReductionTestType) && (type != TeamBroadcastTestType) && (type != PingAllTestType) && (type != TeamBarrierTestType) && (type != TeamWAVEBarrierTestType) && - (type != TeamWGBarrierTestType)) { + (type != TeamWGBarrierTestType) && (type != TeamCtxInfraTestBlockType) && + (type != TeamCtxInfraTestOddEvenType)) { if (numprocs != 2) { if (myid == 0) { std::cerr << "This test requires exactly two processes, we have " diff --git a/projects/rocshmem/tests/functional_tests/tester_arguments.hpp b/projects/rocshmem/tests/functional_tests/tester_arguments.hpp index 61478fb1e6..217e54d50f 100644 --- a/projects/rocshmem/tests/functional_tests/tester_arguments.hpp +++ b/projects/rocshmem/tests/functional_tests/tester_arguments.hpp @@ -31,6 +31,14 @@ #include #include + +enum TeamSplitType { + ROCSHMEM_TEST_TEAM_DUP = 0, // Dup parent team + ROCSHMEM_TEST_TEAM_SINGLE, // each PE will be its own team + ROCSHMEM_TEST_TEAM_BLOCK, // split parent into two halves + ROCSHMEM_TEST_TEAM_ODDEVEN, // odd-even splitting +}; + class TesterArguments { public: TesterArguments(int argc, char *argv[]); @@ -47,7 +55,7 @@
class TesterArguments { */ static void show_usage(std::string executable_name); - public: +public: /** * Arguments obtained from command line */ @@ -75,6 +83,8 @@ class TesterArguments { int skip = 10; int loop_large = 10; size_t large_message_size = 32768; + + TeamSplitType team_type = ROCSHMEM_TEST_TEAM_DUP; }; #endif