diff --git a/wddm/queue.cpp b/wddm/queue.cpp index c10a034705..44658819cb 100644 --- a/wddm/queue.cpp +++ b/wddm/queue.cpp @@ -499,12 +499,12 @@ bool ComputeQueue::UpdateScratch(hsa_kernel_dispatch_packet_t *packet, bool wave const uint64_t size_per_thread = AlignUp(packet->private_segment_size, scratch_mem_alignment_size_ / lanes_per_wave); - scratch_size_per_wave_ = size_per_thread * lanes_per_wave; - uint64_t groups = CalcDispatchGroups(packet); uint64_t waves_per_group = CalcDispatchWavesPerGroup(packet, wave32); - dispatch_waves_ = groups * waves_per_group; + // For packet batching, the maximum value must be used to fit all packets. + scratch_size_per_wave_ = std::max(size_per_thread * lanes_per_wave, scratch_size_per_wave_); + dispatch_waves_ = std::max(groups * waves_per_group, dispatch_waves_); const uint64_t max_scratch_size = scratch_size_per_wave_ * max_scratch_waves_; const uint64_t dispatch_size = scratch_size_per_wave_ * dispatch_waves_;