diff --git a/runtime/hsa-runtime/core/inc/scratch_cache.h b/runtime/hsa-runtime/core/inc/scratch_cache.h index 5029f4781d..e85e22660f 100644 --- a/runtime/hsa-runtime/core/inc/scratch_cache.h +++ b/runtime/hsa-runtime/core/inc/scratch_cache.h @@ -93,6 +93,7 @@ class ScratchCache { size_t dispatch_size; size_t size_per_thread; uint32_t lanes_per_wave; + uint32_t waves_per_group; ptrdiff_t queue_process_offset; bool large; bool retry; diff --git a/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp b/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp index 03d5c9285f..466afe2beb 100644 --- a/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp +++ b/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp @@ -798,6 +798,8 @@ bool AqlQueue::DynamicScratchHandler(hsa_signal_value_t error_code, void* arg) { pkt.dispatch.workgroup_size_z; uint64_t waves_per_group = (lanes_per_group + scratch.lanes_per_wave - 1) / scratch.lanes_per_wave; + scratch.waves_per_group = waves_per_group; + uint64_t groups = ((uint64_t(pkt.dispatch.grid_size_x) + pkt.dispatch.workgroup_size_x - 1) / pkt.dispatch.workgroup_size_x) * ((uint64_t(pkt.dispatch.grid_size_y) + pkt.dispatch.workgroup_size_y - 1) / @@ -805,10 +807,11 @@ bool AqlQueue::DynamicScratchHandler(hsa_signal_value_t error_code, void* arg) { ((uint64_t(pkt.dispatch.grid_size_z) + pkt.dispatch.workgroup_size_z - 1) / pkt.dispatch.workgroup_size_z); + // Assign an equal number of groups to each engine, clipping to capacity limits + const uint32_t engines = queue->agent_->properties().NumShaderBanks; + groups = ((groups + engines - 1) / engines) * engines; scratch.wanted_slots = groups * waves_per_group; scratch.wanted_slots = Min(scratch.wanted_slots, uint64_t(MaxScratchSlots)); - scratch.wanted_slots = - Max(scratch.wanted_slots, uint64_t(queue->agent_->properties().NumShaderBanks)); scratch.dispatch_size = scratch.size_per_thread * scratch.wanted_slots * scratch.lanes_per_wave; diff --git a/runtime/hsa-runtime/core/runtime/amd_gpu_agent.cpp b/runtime/hsa-runtime/core/runtime/amd_gpu_agent.cpp index ba8712273a..a6c5586d8f 100644 --- a/runtime/hsa-runtime/core/runtime/amd_gpu_agent.cpp +++ b/runtime/hsa-runtime/core/runtime/amd_gpu_agent.cpp @@ -1185,7 +1185,7 @@ void GpuAgent::AcquireQueueScratch(ScratchInfo& scratch) { return; } scratch_pool_.free(base); - waves_per_cu--; + waves_per_cu = waves_per_cu - scratch.waves_per_group; } // Failed to allocate minimal scratch