Adding bounds check before hipMemset (#1190)

* Adding bounds check in ihipMemset

* Adding ihipMemPtrGetInfo to hipMemPtrGetInfo


[ROCm/hip commit: 5ed16432f8]
Этот коммит содержится в:
Jatin Chaudhary
2019-07-08 16:30:38 +05:30
коммит произвёл Maneesh Gupta
родитель 8cd7322740
Коммит c7f8ffe41e
+30 -21
Просмотреть файл
@@ -1506,6 +1506,29 @@ __global__ void hip_copy2d_n(T* dst, const T* src, size_t width, size_t height,
}
} // namespace
//Get the allocated size
hipError_t ihipMemPtrGetInfo(void* ptr, size_t* size) {
hipError_t e = hipSuccess;
if (ptr != nullptr && size != nullptr) {
*size = 0;
hc::accelerator acc;
#if (__hcc_workweek__ >= 17332)
hc::AmPointerInfo amPointerInfo(NULL, NULL, NULL, 0, acc, 0, 0);
#else
hc::AmPointerInfo amPointerInfo(NULL, NULL, 0, acc, 0, 0);
#endif
am_status_t status = hc::am_memtracker_getinfo(&amPointerInfo, ptr);
if (status == AM_SUCCESS) {
*size = amPointerInfo._sizeBytes;
} else {
e = hipErrorInvalidValue;
}
} else {
e = hipErrorInvalidValue;
}
return e;
}
template <typename T>
void ihipMemsetKernel(hipStream_t stream, T* ptr, T val, size_t count) {
static constexpr uint32_t block_dim = 256;
@@ -1532,13 +1555,17 @@ typedef enum ihipMemsetDataType {
ihipMemsetDataTypeInt = 2
}ihipMemsetDataType;
hipError_t ihipMemset(void* dst, int value, size_t count, hipStream_t stream, enum ihipMemsetDataType copyDataType )
hipError_t ihipMemset(void* dst, int value, size_t count, hipStream_t stream, enum ihipMemsetDataType copyDataType)
{
hipError_t e = hipSuccess;
if (count == 0) return e;
if (stream && (dst != NULL)) {
size_t allocSize = 0;
bool isInbound = (ihipMemPtrGetInfo(dst, &allocSize) == hipSuccess);
isInbound &= (allocSize >= count);
if (stream && (dst != NULL) && isInbound) {
if(copyDataType == ihipMemsetDataTypeChar){
if ((count & 0x3) == 0) {
// use a faster dword-per-workitem copy:
@@ -1898,25 +1925,7 @@ hipError_t hipMemGetInfo(size_t* free, size_t* total) {
hipError_t hipMemPtrGetInfo(void* ptr, size_t* size) {
HIP_INIT_API(hipMemPtrGetInfo, ptr, size);
hipError_t e = hipSuccess;
if (ptr != nullptr && size != nullptr) {
hc::accelerator acc;
#if (__hcc_workweek__ >= 17332)
hc::AmPointerInfo amPointerInfo(NULL, NULL, NULL, 0, acc, 0, 0);
#else
hc::AmPointerInfo amPointerInfo(NULL, NULL, 0, acc, 0, 0);
#endif
am_status_t status = hc::am_memtracker_getinfo(&amPointerInfo, ptr);
if (status == AM_SUCCESS) {
*size = amPointerInfo._sizeBytes;
} else {
e = hipErrorInvalidValue;
}
} else {
e = hipErrorInvalidValue;
}
return ihipLogStatus(e);
return ihipLogStatus(ihipMemPtrGetInfo(ptr, size));
}