diff --git a/hipamd/src/hip_memory.cpp b/hipamd/src/hip_memory.cpp index e7bc348951..0aee149c9f 100644 --- a/hipamd/src/hip_memory.cpp +++ b/hipamd/src/hip_memory.cpp @@ -2219,16 +2219,40 @@ hipError_t hipMemset(void* dst, int value, size_t sizeBytes) { return ihipLogStatus(ihipMemsetSync(dst, value, sizeBytes, nullptr, ihipMemsetDataTypeChar)); } +hipError_t ihipMemsetND(void* dst, size_t pitch, int value, size_t width, size_t height, size_t setHeight,size_t depth, + hipStream_t stream, enum ihipMemsetDataType copyDataType, bool async) { + size_t sizeBytes =0; + hipError_t hipStatus = hipSuccess; + if ((pitch == width) && (height == setHeight)) { + sizeBytes = pitch * setHeight * depth; + if(async) + return ihipMemsetAsync(dst, value, sizeBytes, stream, copyDataType); + else + return ihipMemsetSync(dst, value, sizeBytes, nullptr, copyDataType); + } else { + for(size_t i = 0; i < depth; ++i) { + for(size_t j = 0; j < setHeight; ++j) { + void* dstPtr = ((unsigned char*) dst + i * height * pitch + j * pitch); + if(async) + hipStatus = ihipMemsetAsync(dstPtr, value, width, stream, copyDataType); + else + hipStatus = ihipMemsetSync(dstPtr, value, width, nullptr, copyDataType); + if (hipStatus != hipSuccess) + return hipStatus; + } + } + } + return hipStatus; +} + hipError_t hipMemset2D(void* dst, size_t pitch, int value, size_t width, size_t height) { HIP_INIT_SPECIAL_API(hipMemset2D, (TRACE_MCMD), dst, pitch, value, width, height); - size_t sizeBytes = pitch * height; - return ihipLogStatus(ihipMemsetSync(dst, value, sizeBytes, nullptr, ihipMemsetDataTypeChar)); + return ihipLogStatus(ihipMemsetND(dst, pitch, value, width, height, height, 1, hipStreamNull, ihipMemsetDataTypeChar, false)); } hipError_t hipMemset2DAsync(void* dst, size_t pitch, int value, size_t width, size_t height, hipStream_t stream ) { HIP_INIT_SPECIAL_API(hipMemset2DAsync, (TRACE_MCMD), dst, pitch, value, width, height, stream); - size_t sizeBytes = pitch * height; - return ihipLogStatus(ihipMemsetAsync(dst, value, sizeBytes, stream, ihipMemsetDataTypeChar)); + return ihipLogStatus(ihipMemsetND(dst, pitch, value, width, height, height, 1, stream, ihipMemsetDataTypeChar, true)); } hipError_t hipMemsetD8(hipDeviceptr_t dst, unsigned char value, size_t count) { @@ -2258,14 +2282,12 @@ hipError_t hipMemsetD32(hipDeviceptr_t dst, int value, size_t count) { hipError_t hipMemset3D(hipPitchedPtr pitchedDevPtr, int value, hipExtent extent) { HIP_INIT_SPECIAL_API(hipMemset3D, (TRACE_MCMD), &pitchedDevPtr, value, &extent); - size_t sizeBytes = pitchedDevPtr.pitch * extent.height * extent.depth; - return ihipLogStatus(ihipMemsetSync(pitchedDevPtr.ptr, value, sizeBytes, nullptr, ihipMemsetDataTypeChar)); + return ihipLogStatus(ihipMemsetND(pitchedDevPtr.ptr, pitchedDevPtr.pitch ,value, extent.width, pitchedDevPtr.ysize, extent.height, extent.depth, hipStreamNull, ihipMemsetDataTypeChar, false)); } hipError_t hipMemset3DAsync(hipPitchedPtr pitchedDevPtr, int value, hipExtent extent ,hipStream_t stream ) { HIP_INIT_SPECIAL_API(hipMemset3DAsync, (TRACE_MCMD), &pitchedDevPtr, value, &extent); - size_t sizeBytes = pitchedDevPtr.pitch * extent.height * extent.depth; - return ihipLogStatus(ihipMemsetAsync(pitchedDevPtr.ptr, value, sizeBytes, stream, ihipMemsetDataTypeChar)); + return ihipLogStatus(ihipMemsetND(pitchedDevPtr.ptr,pitchedDevPtr.pitch, value, extent.width, pitchedDevPtr.ysize, extent.height, extent.depth, stream, ihipMemsetDataTypeChar, true)); } hipError_t hipMemGetInfo(size_t* free, size_t* total) {