From e67113a7413fa2b120c4600680b170e6f4d77e36 Mon Sep 17 00:00:00 2001 From: Longlong Yao Date: Wed, 7 Jan 2026 15:10:22 +0800 Subject: [PATCH] wsl/librocdxg: correct scratch info for kernel dispatch The scratch_size_per_wave_ and dispatch_waves_ should use the maximum values from all packets in the batch. Signed-off-by: Longlong Yao Reviewed-by: Flora Cui --- wddm/queue.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wddm/queue.cpp b/wddm/queue.cpp index c10a034705..44658819cb 100644 --- a/wddm/queue.cpp +++ b/wddm/queue.cpp @@ -499,12 +499,12 @@ bool ComputeQueue::UpdateScratch(hsa_kernel_dispatch_packet_t *packet, bool wave const uint64_t size_per_thread = AlignUp(packet->private_segment_size, scratch_mem_alignment_size_ / lanes_per_wave); - scratch_size_per_wave_ = size_per_thread * lanes_per_wave; - uint64_t groups = CalcDispatchGroups(packet); uint64_t waves_per_group = CalcDispatchWavesPerGroup(packet, wave32); - dispatch_waves_ = groups * waves_per_group; + // For packet batching, the maximum value must be used to fit all packets. + scratch_size_per_wave_ = std::max(size_per_thread * lanes_per_wave, scratch_size_per_wave_); + dispatch_waves_ = std::max(groups * waves_per_group, dispatch_waves_); const uint64_t max_scratch_size = scratch_size_per_wave_ * max_scratch_waves_; const uint64_t dispatch_size = scratch_size_per_wave_ * dispatch_waves_;