diff --git a/hipamd/src/hip_memory.cpp b/hipamd/src/hip_memory.cpp index c536164f45..43399bfdb5 100644 --- a/hipamd/src/hip_memory.cpp +++ b/hipamd/src/hip_memory.cpp @@ -1273,8 +1273,18 @@ hipError_t hipMemcpyToSymbol(void* dst, const void* src, size_t count, tprintf(DB_MEM, " symbol '%s' resolved to address:%p\n", symbol_name, dst); - return ihipLogStatus( - hipMemcpy(static_cast(dst) + offset, src, count, kind)); + if (dst == nullptr) { + return ihipLogStatus(hipErrorInvalidSymbol); + } + + if (kind == hipMemcpyDeviceToHost || kind == hipMemcpyHostToHost) { + return ihipLogStatus(hipErrorInvalidMemcpyDirection); + } else if (kind == hipMemcpyDeviceToDevice) { + return ihipLogStatus(hipErrorInvalidValue); + } + + return ihipLogStatus(hip_internal::memcpySync(static_cast(dst)+offset, src, count, kind, + hipStreamNull)); } hipError_t hipMemcpyFromSymbol(void* dst, const void* src, size_t count, @@ -1285,8 +1295,18 @@ hipError_t hipMemcpyFromSymbol(void* dst, const void* src, size_t count, tprintf(DB_MEM, " symbol '%s' resolved to address:%p\n", symbol_name, dst); - return ihipLogStatus( - hipMemcpy(dst, static_cast(src) + offset, count, kind)); + if (src == nullptr || dst == nullptr) { + return ihipLogStatus(hipErrorInvalidSymbol); + } + + if (kind == hipMemcpyHostToDevice || kind == hipMemcpyHostToHost) { + return ihipLogStatus(hipErrorInvalidMemcpyDirection); + } else if (kind == hipMemcpyDeviceToDevice) { + return ihipLogStatus(hipErrorInvalidValue); + } + + return ihipLogStatus(hip_internal::memcpySync(dst, static_cast(src)+offset, count, kind, + hipStreamNull)); } @@ -1301,11 +1321,17 @@ hipError_t hipMemcpyToSymbolAsync(void* dst, const void* src, size_t count, if (dst == nullptr) { return ihipLogStatus(hipErrorInvalidSymbol); } + + if (kind == hipMemcpyDeviceToHost || kind == hipMemcpyHostToHost) { + return ihipLogStatus(hipErrorInvalidMemcpyDirection); + } else if (kind == hipMemcpyDeviceToDevice) { + return ihipLogStatus(hipErrorInvalidValue); + } hipError_t e = hipSuccess; if (stream) { try { - hip_internal::memcpyAsync((char*)dst+offset, src, count, kind, stream); + hip_internal::memcpyAsync(static_cast(dst)+offset, src, count, kind, stream); } catch (ihipException& ex) { e = ex._code; } @@ -1327,12 +1353,18 @@ hipError_t hipMemcpyFromSymbolAsync(void* dst, const void* src, size_t count, if (src == nullptr || dst == nullptr) { return ihipLogStatus(hipErrorInvalidSymbol); } + + if (kind == hipMemcpyHostToDevice || kind == hipMemcpyHostToHost) { + return ihipLogStatus(hipErrorInvalidMemcpyDirection); + } else if (kind == hipMemcpyDeviceToDevice) { + return ihipLogStatus(hipErrorInvalidValue); + } hipError_t e = hipSuccess; stream = ihipSyncAndResolveStream(stream); if (stream) { try { - hip_internal::memcpyAsync(dst, (char*)src+offset, count, kind, stream); + hip_internal::memcpyAsync(dst, static_cast(src)+offset, count, kind, stream); } catch (ihipException& ex) { e = ex._code; }