diff --git a/projects/clr/rocclr/device/device.cpp b/projects/clr/rocclr/device/device.cpp index 2ba0e7727f..fe59d216de 100644 --- a/projects/clr/rocclr/device/device.cpp +++ b/projects/clr/rocclr/device/device.cpp @@ -909,7 +909,7 @@ bool Device::UpdateStackSize(uint64_t stackSize) { if (stackSize > kStackSize) { return false; } - stack_size_ = stackSize; + stack_size_ = amd::alignUp(stackSize, 16); return true; } diff --git a/projects/clr/rocclr/device/pal/paldevice.cpp b/projects/clr/rocclr/device/pal/paldevice.cpp index 42c1f25cbf..a0de81959a 100644 --- a/projects/clr/rocclr/device/pal/paldevice.cpp +++ b/projects/clr/rocclr/device/pal/paldevice.cpp @@ -2277,7 +2277,8 @@ bool Device::validateKernel(const amd::Kernel& kernel, const device::VirtualDevi bool coop_groups) { // Find the number of scratch registers used in the kernel const device::Kernel* devKernel = kernel.getDeviceKernel(*this); - uint regNum = static_cast(devKernel->workGroupInfo()->scratchRegs_); + uint32_t regNum = static_cast(devKernel->workGroupInfo()->scratchRegs_); + regNum = std::max(static_cast(stack_size_) / sizeof(uint32_t), regNum); const VirtualGPU* vgpu = static_cast(vdev); if (!allocScratch(regNum, vgpu, devKernel->workGroupInfo()->usedVGPRs_)) { diff --git a/projects/clr/rocclr/device/pal/palvirtual.cpp b/projects/clr/rocclr/device/pal/palvirtual.cpp index ba3f4fc727..5511e42122 100644 --- a/projects/clr/rocclr/device/pal/palvirtual.cpp +++ b/projects/clr/rocclr/device/pal/palvirtual.cpp @@ -2685,16 +2685,24 @@ bool VirtualGPU::submitKernelInternal(const amd::NDRangeContainer& sizes, LogError("Couldn't load kernel arguments"); return false; } + // The dynamic call stack size is taken into account when the private segment size and scratch + // regs are calculated in LightningKernel::postLoad(). Because postLoad() is not called during + // hipModuleLaunchKernel (unlike hipLaunchKernel/hipLaunchKernelGGL), the updated value is passed to the dispatch packet here. 
+ size_t privateMemSize = hsaKernel.spillSegSize(); + if ((hsaKernel.workGroupInfo()->usedStackSize_ & 0x1) == 0x1) { + privateMemSize = std::max(static_cast(device().StackSize()), + hsaKernel.workGroupInfo()->scratchRegs_ * sizeof(uint32_t)) ; + } // Set up the dispatch information Pal::DispatchAqlParams dispatchParam = {}; dispatchParam.pAqlPacket = aqlPkt; - if (hsaKernel.workGroupInfo()->scratchRegs_ > 0) { + if (privateMemSize > 0) { const Device::ScratchBuffer* scratch = dev().scratch(hwRing()); dispatchParam.scratchAddr = scratch->memObj_->vmAddress(); dispatchParam.scratchSize = scratch->size_; dispatchParam.scratchOffset = scratch->offset_; - dispatchParam.workitemPrivateSegmentSize = hsaKernel.spillSegSize(); + dispatchParam.workitemPrivateSegmentSize = privateMemSize; } dispatchParam.pCpuAqlCode = hsaKernel.cpuAqlCode(); dispatchParam.hsaQueueVa = hsaQueueMem_->vmAddress();