wsl/librocdxg: correct scratch info for kernel dispatch

The scratch_size_per_wave_ and dispatch_waves_ should use
the maximum values from all packets in the batch.

Signed-off-by: Longlong Yao <Longlong.Yao@amd.com>
Reviewed-by: Flora Cui <flora.cui@amd.com>
This commit is contained in:
Longlong Yao
2026-01-07 15:10:22 +08:00
committed by Flora Cui
orang tua c3f55c8e59
melakukan e67113a741
+3 -3
Melihat File
@@ -499,12 +499,12 @@ bool ComputeQueue::UpdateScratch(hsa_kernel_dispatch_packet_t *packet, bool wave
const uint64_t size_per_thread = AlignUp(packet->private_segment_size,
scratch_mem_alignment_size_ / lanes_per_wave);
scratch_size_per_wave_ = size_per_thread * lanes_per_wave;
uint64_t groups = CalcDispatchGroups(packet);
uint64_t waves_per_group = CalcDispatchWavesPerGroup(packet, wave32);
dispatch_waves_ = groups * waves_per_group;
// For packet batching, the maximum value must be used to fit all packets.
scratch_size_per_wave_ = std::max(size_per_thread * lanes_per_wave, scratch_size_per_wave_);
dispatch_waves_ = std::max(groups * waves_per_group, dispatch_waves_);
const uint64_t max_scratch_size = scratch_size_per_wave_ * max_scratch_waves_;
const uint64_t dispatch_size = scratch_size_per_wave_ * dispatch_waves_;