Adding bounds check before hipMemset (#1190)
* Adding bounds check in ihipMemset
* Adding ihipMemPtrGetInfo to hipMemPtrGetInfo
[ROCm/hip commit: 5ed16432f8]
Этот коммит содержится в:
коммит произвёл
Maneesh Gupta
родитель
8cd7322740
Коммит
c7f8ffe41e
@@ -1506,6 +1506,29 @@ __global__ void hip_copy2d_n(T* dst, const T* src, size_t width, size_t height,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
//Get the allocated size
|
||||
hipError_t ihipMemPtrGetInfo(void* ptr, size_t* size) {
|
||||
hipError_t e = hipSuccess;
|
||||
if (ptr != nullptr && size != nullptr) {
|
||||
*size = 0;
|
||||
hc::accelerator acc;
|
||||
#if (__hcc_workweek__ >= 17332)
|
||||
hc::AmPointerInfo amPointerInfo(NULL, NULL, NULL, 0, acc, 0, 0);
|
||||
#else
|
||||
hc::AmPointerInfo amPointerInfo(NULL, NULL, 0, acc, 0, 0);
|
||||
#endif
|
||||
am_status_t status = hc::am_memtracker_getinfo(&amPointerInfo, ptr);
|
||||
if (status == AM_SUCCESS) {
|
||||
*size = amPointerInfo._sizeBytes;
|
||||
} else {
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
} else {
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
return e;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ihipMemsetKernel(hipStream_t stream, T* ptr, T val, size_t count) {
|
||||
static constexpr uint32_t block_dim = 256;
|
||||
@@ -1532,13 +1555,17 @@ typedef enum ihipMemsetDataType {
|
||||
ihipMemsetDataTypeInt = 2
|
||||
}ihipMemsetDataType;
|
||||
|
||||
hipError_t ihipMemset(void* dst, int value, size_t count, hipStream_t stream, enum ihipMemsetDataType copyDataType )
|
||||
hipError_t ihipMemset(void* dst, int value, size_t count, hipStream_t stream, enum ihipMemsetDataType copyDataType)
|
||||
{
|
||||
hipError_t e = hipSuccess;
|
||||
|
||||
if (count == 0) return e;
|
||||
|
||||
if (stream && (dst != NULL)) {
|
||||
size_t allocSize = 0;
|
||||
bool isInbound = (ihipMemPtrGetInfo(dst, &allocSize) == hipSuccess);
|
||||
isInbound &= (allocSize >= count);
|
||||
|
||||
if (stream && (dst != NULL) && isInbound) {
|
||||
if(copyDataType == ihipMemsetDataTypeChar){
|
||||
if ((count & 0x3) == 0) {
|
||||
// use a faster dword-per-workitem copy:
|
||||
@@ -1898,25 +1925,7 @@ hipError_t hipMemGetInfo(size_t* free, size_t* total) {
|
||||
hipError_t hipMemPtrGetInfo(void* ptr, size_t* size) {
|
||||
HIP_INIT_API(hipMemPtrGetInfo, ptr, size);
|
||||
|
||||
hipError_t e = hipSuccess;
|
||||
|
||||
if (ptr != nullptr && size != nullptr) {
|
||||
hc::accelerator acc;
|
||||
#if (__hcc_workweek__ >= 17332)
|
||||
hc::AmPointerInfo amPointerInfo(NULL, NULL, NULL, 0, acc, 0, 0);
|
||||
#else
|
||||
hc::AmPointerInfo amPointerInfo(NULL, NULL, 0, acc, 0, 0);
|
||||
#endif
|
||||
am_status_t status = hc::am_memtracker_getinfo(&amPointerInfo, ptr);
|
||||
if (status == AM_SUCCESS) {
|
||||
*size = amPointerInfo._sizeBytes;
|
||||
} else {
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
} else {
|
||||
e = hipErrorInvalidValue;
|
||||
}
|
||||
return ihipLogStatus(e);
|
||||
return ihipLogStatus(ihipMemPtrGetInfo(ptr, size));
|
||||
}
|
||||
|
||||
|
||||
|
||||
Ссылка в новой задаче
Block a user