diff --git a/runtime/hsa-runtime/core/inc/scratch_cache.h b/runtime/hsa-runtime/core/inc/scratch_cache.h index 9ab50edbf5..a3a4e20250 100644 --- a/runtime/hsa-runtime/core/inc/scratch_cache.h +++ b/runtime/hsa-runtime/core/inc/scratch_cache.h @@ -95,6 +95,7 @@ class ScratchCache { uint32_t lanes_per_wave; uint32_t waves_per_group; uint64_t wanted_slots; + uint32_t mem_alignment_size; bool cooperative; ptrdiff_t queue_process_offset; bool large; diff --git a/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp b/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp index 245f4d3032..921c46a3c9 100644 --- a/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp +++ b/runtime/hsa-runtime/core/runtime/amd_aql_queue.cpp @@ -213,6 +213,11 @@ AqlQueue::AqlQueue(GpuAgent* agent, size_t req_size_pkts, HSAuint32 node_id, Scr assert(amd_queue_.private_segment_aperture_base_hi != 0 && "No private region found."); } + if (agent_->isa()->GetMajorVersion() >= 11) + queue_scratch_.mem_alignment_size = 256; + else + queue_scratch_.mem_alignment_size = 1024; + MAKE_NAMED_SCOPE_GUARD(EventGuard, [&]() { ScopedAcquire _lock(&queue_lock_); queue_count_--; @@ -839,8 +844,10 @@ bool AqlQueue::DynamicScratchHandler(hsa_signal_value_t error_code, void* arg) { scratch.size_per_thread = scratch_request; scratch.lanes_per_wave = (error_code & 0x400) ? 32 : 64; - // Align whole waves to 1KB. - scratch.size_per_thread = AlignUp(scratch.size_per_thread, 1024 / scratch.lanes_per_wave); + + scratch.size_per_thread = + AlignUp(scratch.size_per_thread, scratch.mem_alignment_size / scratch.lanes_per_wave); + scratch.size = scratch.size_per_thread * MaxScratchSlots * scratch.lanes_per_wave; // Smaller dispatches may not need to reach full device occupancy. @@ -1389,11 +1396,13 @@ void AqlQueue::FillComputeTmpRingSize() { // Scratch is allocated program COMPUTE_TMPRING_SIZE register // Scratch Size per Wave is specified in terms of kilobytes - uint32_t wave_scratch = (((queue_scratch_.lanes_per_wave * - queue_scratch_.size_per_thread) + 1023) / 1024); + uint32_t wave_scratch = (((queue_scratch_.lanes_per_wave * queue_scratch_.size_per_thread) + + queue_scratch_.mem_alignment_size - 1) / + queue_scratch_.mem_alignment_size); tmpring_size.bits.WAVESIZE = wave_scratch; assert(wave_scratch == tmpring_size.bits.WAVESIZE && "WAVESIZE Overflow."); - uint32_t num_waves = queue_scratch_.size / (tmpring_size.bits.WAVESIZE * 1024); + uint32_t num_waves = + queue_scratch_.size / (tmpring_size.bits.WAVESIZE * queue_scratch_.mem_alignment_size); tmpring_size.bits.WAVES = std::min(num_waves, max_scratch_waves); amd_queue_.compute_tmpring_size = tmpring_size.u32All; assert((tmpring_size.bits.WAVES % agent_props.NumShaderBanks == 0) && @@ -1415,11 +1424,15 @@ void AqlQueue::FillComputeTmpRingSize_Gfx11() { // Scratch is allocated program COMPUTE_TMPRING_SIZE register // Scratch Size per Wave is specified in terms of kilobytes - uint32_t wave_scratch = - (((queue_scratch_.lanes_per_wave * queue_scratch_.size_per_thread) + 255) / 256); + uint32_t wave_scratch = (((queue_scratch_.lanes_per_wave * queue_scratch_.size_per_thread) + + queue_scratch_.mem_alignment_size - 1) / + queue_scratch_.mem_alignment_size); + tmpring_size.bits.WAVESIZE = wave_scratch; assert(wave_scratch == tmpring_size.bits.WAVESIZE && "WAVESIZE Overflow."); - uint32_t num_waves = queue_scratch_.size / (tmpring_size.bits.WAVESIZE * 256); + + uint32_t num_waves = + queue_scratch_.size / (tmpring_size.bits.WAVESIZE * queue_scratch_.mem_alignment_size); // For GFX11 we specify number of waves per engine instead of total num_waves /= agent_->properties().NumShaderBanks;