diff --git a/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp b/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp index 6b89400f46..e66ac355a1 100644 --- a/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp +++ b/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp @@ -831,8 +831,11 @@ bool AqlQueue::DynamicScratchHandler(hsa_signal_value_t error_code, void* arg) { assert((scratch_request != 0) && "Scratch memory request from packet with no scratch demand. Possible bad kernel code object."); + // Get the hw maximum scratch slot count taking into consideration asymmetric harvest. + const uint32_t engines = queue->agent_->properties().NumShaderBanks; + const uint32_t cu_count = queue->amd_queue_.max_cu_id + 1; const uint32_t MaxScratchSlots = - (queue->amd_queue_.max_cu_id + 1) * queue->agent_->properties().MaxSlotsScratchCU; + AlignUp(cu_count, engines) * queue->agent_->properties().MaxSlotsScratchCU; scratch.size_per_thread = scratch_request; scratch.lanes_per_wave = (error_code & 0x400) ? 32 : 64; @@ -840,6 +843,9 @@ bool AqlQueue::DynamicScratchHandler(hsa_signal_value_t error_code, void* arg) { scratch.size_per_thread = AlignUp(scratch.size_per_thread, 1024 / scratch.lanes_per_wave); scratch.size = scratch.size_per_thread * MaxScratchSlots * scratch.lanes_per_wave; + // Smaller dispatches may not need to reach full device occupancy. + // For these we need to ensure that the scratch we give doesn't restrict the dispatch even + // though it does not fill the device. Figure the total requested dispatch size. uint64_t lanes_per_group = (uint64_t(pkt.dispatch.workgroup_size_x) * pkt.dispatch.workgroup_size_y) * pkt.dispatch.workgroup_size_z; @@ -854,9 +860,17 @@ bool AqlQueue::DynamicScratchHandler(hsa_signal_value_t error_code, void* arg) { ((uint64_t(pkt.dispatch.grid_size_z) + pkt.dispatch.workgroup_size_z - 1) / pkt.dispatch.workgroup_size_z); - // Assign an equal number of groups to each engine, clipping to capacity limits - const uint32_t engines = queue->agent_->properties().NumShaderBanks; - groups = ((groups + engines - 1) / engines) * engines; + // Find the maximum number of groups assigned to any engine. + const uint32_t symmetric_cus = AlignDown(cu_count, engines); + const uint32_t asymmetryPerRound = cu_count - symmetric_cus; + const uint64_t rounds = groups / cu_count; + const uint64_t asymmetricGroups = rounds * asymmetryPerRound; + const uint64_t symmetricGroups = groups - asymmetricGroups; + const uint64_t maxGroupsPerEngine = + ((symmetricGroups + engines - 1) / engines) + (asymmetryPerRound ? rounds : 0); + + // Populate all engines at max group occupancy, then clip down to device limits. + groups = maxGroupsPerEngine * engines; scratch.wanted_slots = groups * waves_per_group; scratch.wanted_slots = Min(scratch.wanted_slots, uint64_t(MaxScratchSlots)); scratch.dispatch_size = diff --git a/runtime/hsa-runtime/core/util/utils.h b/runtime/hsa-runtime/core/util/utils.h index b1901b6520..ab536ba796 100644 --- a/runtime/hsa-runtime/core/util/utils.h +++ b/runtime/hsa-runtime/core/util/utils.h @@ -237,8 +237,7 @@ static __forceinline bool IsPowerOfTwo(T val) { /// @return: T. template static __forceinline T AlignDown(T value, size_t alignment) { - assert(IsPowerOfTwo(alignment)); - return (T)(value & ~(alignment - 1)); + return (T)((value / alignment) * alignment); } /// @brief: Same as previous one, but first parameter becomes pointer, for more