SWDEV-567545 - Implement block_rank in co-op grid groups (#2182)

* SWDEV-567545 - Implement block_rank in co-op grid groups
Bu işleme şunda yer alıyor:
Jimbo
2025-12-29 11:39:23 -05:00
işlemeyi yapan: GitHub
ebeveyn 5bf6e366dd
işleme a59d46ffbf
6 değiştirilmiş dosya ile 59 ekleme ve 1 silme
+3
Dosyayı Görüntüle
@@ -9,6 +9,9 @@ Full documentation for HIP is available at [rocm.docs.amd.com](https://rocm.docs
* New HIP APIs
- `hipKernelGetParamInfo` returns the offset and size of a kernel parameter
* New HIP supports
- `grid_group::block_rank()` returns the rank of the block in the calling thread
## HIP 7.2 for ROCm 7.2
### Added
+8 -1
Dosyayı Görüntüle
@@ -96,6 +96,8 @@ class thread_group {
__CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
//! Rank of the calling thread within [0, \link num_threads() num_threads() \endlink).
__CG_QUALIFIER__ __hip_uint32_t thread_rank() const;
//! Rank of the block in calling thread within [0, \link num_threads() num_threads() \endlink).
__CG_QUALIFIER__ __hip_uint32_t block_rank() const;
//! Returns true if the group has not violated any API constraints.
__CG_QUALIFIER__ bool is_valid() const;
@@ -203,6 +205,8 @@ class grid_group : public thread_group {
public:
//! @copydoc thread_group::thread_rank
__CG_QUALIFIER__ __hip_uint32_t thread_rank() const { return internal::grid::thread_rank(); }
//! @copydoc thread_group::block_rank
__CG_QUALIFIER__ __hip_uint32_t block_rank() const { return internal::grid::block_rank(); }
//! @copydoc thread_group::is_valid
__CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
//! @copydoc thread_group::sync
@@ -275,6 +279,10 @@ class thread_block : public thread_group {
__CG_STATIC_QUALIFIER__ __hip_uint32_t thread_rank() {
return internal::workgroup::thread_rank();
}
//! @copydoc thread_group::block_rank
__CG_STATIC_QUALIFIER__ __hip_uint32_t block_rank() {
return internal::workgroup::block_rank();
}
//! @copydoc thread_group::num_threads
__CG_STATIC_QUALIFIER__ __hip_uint32_t num_threads() {
return internal::workgroup::num_threads();
@@ -353,7 +361,6 @@ class tiled_group : public thread_group {
__CG_QUALIFIER__ unsigned int thread_rank() const {
return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.num_threads - 1));
}
//! @copydoc thread_group::sync
__CG_QUALIFIER__ void sync() const { internal::tiled_group::sync(); }
};
+10
Dosyayı Görüntüle
@@ -182,6 +182,11 @@ __CG_STATIC_QUALIFIER__ __hip_uint32_t thread_rank() {
return (num_threads_till_current_workgroup + local_thread_rank);
}
__CG_STATIC_QUALIFIER__ __hip_uint32_t block_rank() {
return static_cast<__hip_uint32_t>((blockIdx.z * gridDim.y * gridDim.x) +
(blockIdx.y * gridDim.x) + (blockIdx.x));
}
__CG_STATIC_QUALIFIER__ bool is_valid() { return static_cast<bool>(__ockl_grid_is_valid()); }
__CG_STATIC_QUALIFIER__ void sync() { __ockl_grid_sync(); }
@@ -219,6 +224,11 @@ __CG_STATIC_QUALIFIER__ __hip_uint32_t thread_rank() {
(threadIdx.y * blockDim.x) + (threadIdx.x)));
}
__CG_STATIC_QUALIFIER__ __hip_uint32_t block_rank() {
return (static_cast<__hip_uint32_t>((blockIdx.z * gridDim.x * gridDim.y) +
(blockIdx.y * gridDim.x) + (blockIdx.x)));
}
__CG_STATIC_QUALIFIER__ bool is_valid() { return true; }
__CG_STATIC_QUALIFIER__ void sync() { __syncthreads(); }
+9
Dosyayı Görüntüle
@@ -43,6 +43,15 @@ struct CPUGrid {
return thread_rank_in_grid % threads_in_block_count_;
}
inline std::optional<unsigned int> block_rank_in_grid(
const unsigned int thread_rank_in_grid) const {
if (thread_rank_in_grid > thread_count_) {
return std::nullopt;
}
return thread_rank_in_grid / threads_in_block_count_;
}
inline std::optional<dim3> block_idx(const unsigned int thread_rank_in_grid) const {
if (thread_rank_in_grid > thread_count_) {
return std::nullopt;
+13
Dosyayı Görüntüle
@@ -39,6 +39,10 @@ static __global__ void grid_group_thread_rank_getter(unsigned int* thread_ranks)
thread_ranks[thread_rank_in_grid()] = cg::this_grid().thread_rank();
}
static __global__ void grid_group_block_rank_getter(unsigned int* block_ranks) {
block_ranks[thread_rank_in_grid()] = cg::this_grid().block_rank();
}
static __global__ void grid_group_is_valid_getter(unsigned int* is_valid_flags) {
is_valid_flags[thread_rank_in_grid()] = cg::this_grid().is_valid();
}
@@ -160,9 +164,18 @@ TEST_CASE("Unit_Grid_Group_Getters_Positive_Basic") {
HIP_CHECK(hipMemcpy(uint_arr.ptr(), uint_arr_dev.ptr(),
grid.thread_count_ * sizeof(*uint_arr.ptr()), hipMemcpyDeviceToHost));
HIP_CHECK(hipDeviceSynchronize());
HIP_CHECK(hipLaunchCooperativeKernel(grid_group_block_rank_getter, blocks, threads, params, 0, 0));
// Verify grid_group.is_valid() values
ArrayAllOf(uint_arr.ptr(), grid.thread_count_, [](uint32_t) { return 1; });
HIP_CHECK(hipMemcpy(uint_arr.ptr(), uint_arr_dev.ptr(),
grid.thread_count_ * sizeof(*uint_arr.ptr()), hipMemcpyDeviceToHost));
HIP_CHECK(hipDeviceSynchronize());
// Verify grid_group.block_rank() values
ArrayAllOf(uint_arr.ptr(), grid.thread_count_, [threads](uint32_t i) {
return i/(threads.x * threads.y * threads.z); });
}
/**
+16
Dosyayı Görüntüle
@@ -49,6 +49,12 @@ static __global__ void thread_block_thread_rank_getter(unsigned int* thread_rank
thread_ranks[thread_rank_in_grid()] = group.thread_rank();
}
template <typename BaseType = cg::thread_block>
static __global__ void thread_block_block_rank_getter(unsigned int* block_ranks) {
const BaseType group = cg::this_thread_block();
block_ranks[thread_rank_in_grid()] = group.block_rank();
}
static __global__ void thread_block_group_indices_getter(dim3* group_indices) {
group_indices[thread_rank_in_grid()] = cg::this_thread_block().group_index();
}
@@ -111,10 +117,20 @@ TEST_CASE("Unit_Thread_Block_Getters_Positive_Basic") {
HIP_CHECK(hipMemcpy(uint_arr.ptr(), uint_arr_dev.ptr(),
grid.thread_count_ * sizeof(*uint_arr.ptr()), hipMemcpyDeviceToHost));
HIP_CHECK(hipDeviceSynchronize());
thread_block_block_rank_getter<<<blocks, threads>>>(uint_arr_dev.ptr());
HIP_CHECK(hipGetLastError());
// Validate thread_block.thread_rank() values
ArrayAllOf(uint_arr.ptr(), grid.thread_count_,
[&grid](uint32_t i) { return grid.thread_rank_in_block(i).value(); });
HIP_CHECK(hipMemcpy(uint_arr.ptr(), uint_arr_dev.ptr(),
grid.thread_count_ * sizeof(*uint_arr.ptr()), hipMemcpyDeviceToHost));
HIP_CHECK(hipDeviceSynchronize());
// Validate thread_block.block_rank() values
ArrayAllOf(uint_arr.ptr(), grid.thread_count_,
[&grid](uint32_t i) { return grid.block_rank_in_grid(i).value(); });
}
{