From 9b2f22f7aa59cd2ff7c8014e56f03c00afac38d8 Mon Sep 17 00:00:00 2001 From: satyanveshd <53337087+satyanveshd@users.noreply.github.com> Date: Tue, 7 Jan 2020 08:11:53 +0530 Subject: [PATCH] hipMemcpy[To/From]Symbol(Async) fixes (#1774) --- hipamd/src/hip_memory.cpp | 44 +++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/hipamd/src/hip_memory.cpp b/hipamd/src/hip_memory.cpp index c536164f45..43399bfdb5 100644 --- a/hipamd/src/hip_memory.cpp +++ b/hipamd/src/hip_memory.cpp @@ -1273,8 +1273,18 @@ hipError_t hipMemcpyToSymbol(void* dst, const void* src, size_t count, tprintf(DB_MEM, " symbol '%s' resolved to address:%p\n", symbol_name, dst); - return ihipLogStatus( - hipMemcpy(static_cast(dst) + offset, src, count, kind)); + if (dst == nullptr) { + return ihipLogStatus(hipErrorInvalidSymbol); + } + + if (kind == hipMemcpyDeviceToHost || kind == hipMemcpyHostToHost) { + return ihipLogStatus(hipErrorInvalidMemcpyDirection); + } else if (kind == hipMemcpyDeviceToDevice) { + return ihipLogStatus(hipErrorInvalidValue); + } + + return ihipLogStatus(hip_internal::memcpySync(static_cast(dst)+offset, src, count, kind, + hipStreamNull)); } hipError_t hipMemcpyFromSymbol(void* dst, const void* src, size_t count, @@ -1285,8 +1295,18 @@ hipError_t hipMemcpyFromSymbol(void* dst, const void* src, size_t count, tprintf(DB_MEM, " symbol '%s' resolved to address:%p\n", symbol_name, dst); - return ihipLogStatus( - hipMemcpy(dst, static_cast(src) + offset, count, kind)); + if (src == nullptr || dst == nullptr) { + return ihipLogStatus(hipErrorInvalidSymbol); + } + + if (kind == hipMemcpyHostToDevice || kind == hipMemcpyHostToHost) { + return ihipLogStatus(hipErrorInvalidMemcpyDirection); + } else if (kind == hipMemcpyDeviceToDevice) { + return ihipLogStatus(hipErrorInvalidValue); + } + + return ihipLogStatus(hip_internal::memcpySync(dst, static_cast(src)+offset, count, kind, + hipStreamNull)); } @@ -1301,11 +1321,17 @@ hipError_t hipMemcpyToSymbolAsync(void* dst, const void* src, size_t count, if (dst == nullptr) { return ihipLogStatus(hipErrorInvalidSymbol); } + + if (kind == hipMemcpyDeviceToHost || kind == hipMemcpyHostToHost) { + return ihipLogStatus(hipErrorInvalidMemcpyDirection); + } else if (kind == hipMemcpyDeviceToDevice) { + return ihipLogStatus(hipErrorInvalidValue); + } hipError_t e = hipSuccess; if (stream) { try { - hip_internal::memcpyAsync((char*)dst+offset, src, count, kind, stream); + hip_internal::memcpyAsync(static_cast(dst)+offset, src, count, kind, stream); } catch (ihipException& ex) { e = ex._code; } @@ -1327,12 +1353,18 @@ hipError_t hipMemcpyFromSymbolAsync(void* dst, const void* src, size_t count, if (src == nullptr || dst == nullptr) { return ihipLogStatus(hipErrorInvalidSymbol); } + + if (kind == hipMemcpyHostToDevice || kind == hipMemcpyHostToHost) { + return ihipLogStatus(hipErrorInvalidMemcpyDirection); + } else if (kind == hipMemcpyDeviceToDevice) { + return ihipLogStatus(hipErrorInvalidValue); + } hipError_t e = hipSuccess; stream = ihipSyncAndResolveStream(stream); if (stream) { try { - hip_internal::memcpyAsync(dst, (char*)src+offset, count, kind, stream); + hip_internal::memcpyAsync(dst, static_cast(src)+offset, count, kind, stream); } catch (ihipException& ex) { e = ex._code; }