From fcb0a3d4e2a1863b6d1deeb1b30a0175a8152017 Mon Sep 17 00:00:00 2001 From: Jatin Chaudhary <51944368+cjatin@users.noreply.github.com> Date: Mon, 8 Jul 2019 16:30:38 +0530 Subject: [PATCH] Adding bounds check before hipMemset (#1190) * Adding bounds check in ihipMemset * Adding ihipMemPtrGetInfo to hipMemPtrGetInfo --- hipamd/src/hip_memory.cpp | 51 +++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/hipamd/src/hip_memory.cpp b/hipamd/src/hip_memory.cpp index b85ba61584..720d04deb8 100644 --- a/hipamd/src/hip_memory.cpp +++ b/hipamd/src/hip_memory.cpp @@ -1506,6 +1506,29 @@ __global__ void hip_copy2d_n(T* dst, const T* src, size_t width, size_t height, } } // namespace +//Get the allocated size +hipError_t ihipMemPtrGetInfo(void* ptr, size_t* size) { + hipError_t e = hipSuccess; + if (ptr != nullptr && size != nullptr) { + *size = 0; + hc::accelerator acc; +#if (__hcc_workweek__ >= 17332) + hc::AmPointerInfo amPointerInfo(NULL, NULL, NULL, 0, acc, 0, 0); +#else + hc::AmPointerInfo amPointerInfo(NULL, NULL, 0, acc, 0, 0); +#endif + am_status_t status = hc::am_memtracker_getinfo(&amPointerInfo, ptr); + if (status == AM_SUCCESS) { + *size = amPointerInfo._sizeBytes; + } else { + e = hipErrorInvalidValue; + } + } else { + e = hipErrorInvalidValue; + } + return e; +} + template void ihipMemsetKernel(hipStream_t stream, T* ptr, T val, size_t count) { static constexpr uint32_t block_dim = 256; @@ -1532,13 +1555,17 @@ typedef enum ihipMemsetDataType { ihipMemsetDataTypeInt = 2 }ihipMemsetDataType; -hipError_t ihipMemset(void* dst, int value, size_t count, hipStream_t stream, enum ihipMemsetDataType copyDataType ) +hipError_t ihipMemset(void* dst, int value, size_t count, hipStream_t stream, enum ihipMemsetDataType copyDataType) { hipError_t e = hipSuccess; if (count == 0) return e; - if (stream && (dst != NULL)) { + size_t allocSize = 0; + bool isInbound = (ihipMemPtrGetInfo(dst, &allocSize) == hipSuccess); + isInbound &= (allocSize >= count); + + if (stream && (dst != NULL) && isInbound) { if(copyDataType == ihipMemsetDataTypeChar){ if ((count & 0x3) == 0) { // use a faster dword-per-workitem copy: @@ -1898,25 +1925,7 @@ hipError_t hipMemGetInfo(size_t* free, size_t* total) { hipError_t hipMemPtrGetInfo(void* ptr, size_t* size) { HIP_INIT_API(hipMemPtrGetInfo, ptr, size); - hipError_t e = hipSuccess; - - if (ptr != nullptr && size != nullptr) { - hc::accelerator acc; -#if (__hcc_workweek__ >= 17332) - hc::AmPointerInfo amPointerInfo(NULL, NULL, NULL, 0, acc, 0, 0); -#else - hc::AmPointerInfo amPointerInfo(NULL, NULL, 0, acc, 0, 0); -#endif - am_status_t status = hc::am_memtracker_getinfo(&amPointerInfo, ptr); - if (status == AM_SUCCESS) { - *size = amPointerInfo._sizeBytes; - } else { - e = hipErrorInvalidValue; - } - } else { - e = hipErrorInvalidValue; - } - return ihipLogStatus(e); + return ihipLogStatus(ihipMemPtrGetInfo(ptr, size)); }