diff --git a/projects/clr/rocclr/device/device.cpp b/projects/clr/rocclr/device/device.cpp index e7c86be86d..ac8ed6ce19 100644 --- a/projects/clr/rocclr/device/device.cpp +++ b/projects/clr/rocclr/device/device.cpp @@ -737,8 +737,7 @@ bool Device::disableP2P(amd::Device* ptrDev) { } bool Device::UpdateStackSize(uint64_t stackSize) { - uint32_t maxMemPerThread = info().localMemSizePerCU_ / info().maxThreadsPerCU_; - if (maxMemPerThread < stackSize) { + if (stackSize > 16 * Ki) { return false; } stack_size_ = stackSize; diff --git a/projects/clr/rocclr/device/device.hpp b/projects/clr/rocclr/device/device.hpp index 0be48a7ad6..905f0705e9 100644 --- a/projects/clr/rocclr/device/device.hpp +++ b/projects/clr/rocclr/device/device.hpp @@ -1972,8 +1972,8 @@ class Device : public RuntimeObject { std::once_flag heap_initialized_; //!< Heap buffer initialization flag device::Memory* heap_buffer_; //!< Preallocated heap buffer for memory allocations on device - amd::Memory* arena_mem_obj_; //!< Arena memory object - uint64_t stack_size_{0}; //!< Device stack size + amd::Memory* arena_mem_obj_; //!< Arena memory object + uint64_t stack_size_{1024}; //!< Device stack size private: const Isa *isa_; //!< Device isa diff --git a/projects/clr/rocclr/device/rocm/rocvirtual.cpp b/projects/clr/rocclr/device/rocm/rocvirtual.cpp index 2e4b84981c..63b9637856 100644 --- a/projects/clr/rocclr/device/rocm/rocvirtual.cpp +++ b/projects/clr/rocclr/device/rocm/rocvirtual.cpp @@ -2972,10 +2972,10 @@ bool VirtualGPU::submitKernelInternal(const amd::NDRangeContainer& sizes, dispatchPacket.private_segment_size = devKernel->workGroupInfo()->privateMemSize_; if ((devKernel->workGroupInfo()->usedStackSize_ & 0x1) == 0x1) { - dispatchPacket.private_segment_size += dev().StackSize(); - uint32_t maxMemPerThread = device().info().localMemSizePerCU_ / device().info().maxThreadsPerCU_; - if (dispatchPacket.private_segment_size > maxMemPerThread) { - dispatchPacket.private_segment_size = maxMemPerThread; + dispatchPacket.private_segment_size = + std::max(dev().StackSize(), dispatchPacket.private_segment_size); + if (dispatchPacket.private_segment_size > 16 * Ki) { + dispatchPacket.private_segment_size = 16 * Ki; } }