diff --git a/rocclr/runtime/device/pal/palkernel.cpp b/rocclr/runtime/device/pal/palkernel.cpp index f35ac18b7a..20c06046a8 100644 --- a/rocclr/runtime/device/pal/palkernel.cpp +++ b/rocclr/runtime/device/pal/palkernel.cpp @@ -716,7 +716,10 @@ bool HSAILKernel::init(amd::hsa::loader::Symbol* sym, bool finalize) { workGroupInfo_.size_ = workGroupInfo_.compileSize_[0] * workGroupInfo_.compileSize_[1] * workGroupInfo_.compileSize_[2]; } else { - workGroupInfo_.size_ = dev().info().preferredWorkGroupSize_; + size_t nItems = (workGroupInfo_.availableVGPRs_ / workGroupInfo_.usedVGPRs_) * + dev().hwInfo()->simdPerCU_ * workGroupInfo_.wavefrontSize_; + workGroupInfo_.size_ = nItems > dev().info().preferredWorkGroupSize_ ? + std::min(size_t(1024) , nItems) : dev().info().preferredWorkGroupSize_; } // Pull out printf metadata from the ELF @@ -1435,7 +1438,12 @@ bool LightningKernel::init(amd::hsa::loader::Symbol* symbol) { // Copy wavefront size workGroupInfo_.wavefrontSize_ = dev().info().wavefrontWidth_; - workGroupInfo_.size_ = kernelMD->mCodeProps.mMaxFlatWorkGroupSize; + + size_t nItems = (workGroupInfo_.availableVGPRs_ / workGroupInfo_.usedVGPRs_) * + dev().hwInfo()->simdPerCU_ * workGroupInfo_.wavefrontSize_; + workGroupInfo_.size_ = nItems > kernelMD->mCodeProps.mMaxFlatWorkGroupSize ? + std::min(size_t(1024), nItems) : kernelMD->mCodeProps.mMaxFlatWorkGroupSize; + if (workGroupInfo_.size_ == 0) { return false; }