diff --git a/projects/clr/rocclr/runtime/device/pal/palkernel.cpp b/projects/clr/rocclr/runtime/device/pal/palkernel.cpp index 7a4823ddaa..de15c52a93 100644 --- a/projects/clr/rocclr/runtime/device/pal/palkernel.cpp +++ b/projects/clr/rocclr/runtime/device/pal/palkernel.cpp @@ -29,8 +29,13 @@ namespace pal { void HSAILKernel::setWorkGroupInfo(const uint32_t privateSegmentSize, const uint32_t groupSegmentSize, const uint16_t numSGPRs, const uint16_t numVGPRs) { - workGroupInfo_.scratchRegs_ = amd::alignUp(privateSegmentSize, 16) / sizeof(uint); - workGroupInfo_.privateMemSize_ = privateSegmentSize; + workGroupInfo_.scratchRegs_ = amd::alignUp(privateSegmentSize, 16) / sizeof(uint32_t); + // Make sure runtime matches HW alignment, which is 256 scratch regs (DWORDs) per wave + constexpr uint32_t ScratchRegAlignment = 256; + workGroupInfo_.scratchRegs_ = + amd::alignUp((workGroupInfo_.scratchRegs_ * dev().info().wavefrontWidth_), + ScratchRegAlignment) / dev().info().wavefrontWidth_; + workGroupInfo_.privateMemSize_ = workGroupInfo_.scratchRegs_ * sizeof(uint32_t); workGroupInfo_.localMemSize_ = workGroupInfo_.usedLDSSize_ = groupSegmentSize; workGroupInfo_.usedSGPRs_ = numSGPRs; workGroupInfo_.usedStackSize_ = 0;