diff --git a/projects/hip/src/hip_hcc.cpp b/projects/hip/src/hip_hcc.cpp index 0bcbfd2127..b69daf3684 100644 --- a/projects/hip/src/hip_hcc.cpp +++ b/projects/hip/src/hip_hcc.cpp @@ -2228,6 +2228,25 @@ hipError_t hipMemcpyToSymbol(const char* symbolName, const void *src, size_t cou return ihipLogStatus(hipSuccess); } +// Resolve hipMemcpyDefault to a concrete copy kind (HostToHost, HostToDevice, DeviceToHost or DeviceToDevice) from whether the source and the destination are in device memory. +hipMemcpyKind resolveMemcpyDirection(bool srcInDeviceMem, bool dstInDeviceMem) +{ + hipMemcpyKind kind = hipMemcpyDefault; + + if (!srcInDeviceMem && !dstInDeviceMem) { + kind = hipMemcpyHostToHost; + } else if (!srcInDeviceMem && dstInDeviceMem) { + kind = hipMemcpyHostToDevice; + } else if (srcInDeviceMem && !dstInDeviceMem) { + kind = hipMemcpyDeviceToHost; + } else if (srcInDeviceMem && dstInDeviceMem) { + kind = hipMemcpyDeviceToDevice; + } + + assert (kind != hipMemcpyDefault); // all four flag combinations are handled above + + return kind; +} void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMemcpyKind kind) @@ -2242,34 +2261,23 @@ void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMem hc::AmPointerInfo dstPtrInfo(NULL, NULL, 0, acc, 0, 0); hc::AmPointerInfo srcPtrInfo(NULL, NULL, 0, acc, 0, 0); - bool dstNotTracked = (hc::am_memtracker_getinfo(&dstPtrInfo, dst) != AM_SUCCESS); - bool srcNotTracked = (hc::am_memtracker_getinfo(&srcPtrInfo, src) != AM_SUCCESS); + bool dstTracked = (hc::am_memtracker_getinfo(&dstPtrInfo, dst) == AM_SUCCESS); + bool srcTracked = (hc::am_memtracker_getinfo(&srcPtrInfo, src) == AM_SUCCESS); // Resolve default to a specific Kind so we know which algorithm to use: if (kind == hipMemcpyDefault) { - bool dstIsHost = (dstNotTracked || !dstPtrInfo._isInDeviceMem); - bool srcIsHost = (srcNotTracked || !srcPtrInfo._isInDeviceMem); - if (srcIsHost && !dstIsHost) { - kind = hipMemcpyHostToDevice; - } else if (!srcIsHost && dstIsHost) { - kind = hipMemcpyDeviceToHost; - } else if (srcIsHost && dstIsHost) { - kind = hipMemcpyHostToHost; - } else if (!srcIsHost && !dstIsHost) { 
- kind = hipMemcpyDeviceToDevice; - } else { - throw ihipException(hipErrorInvalidMemcpyDirection); - } - } + bool srcInDeviceMem = (srcTracked && srcPtrInfo._isInDeviceMem); // an untracked pointer is resolved as host memory + bool dstInDeviceMem = (dstTracked && dstPtrInfo._isInDeviceMem); + kind = resolveMemcpyDirection(srcInDeviceMem, dstInDeviceMem); + } hsa_signal_t depSignal; - - if ((kind == hipMemcpyHostToDevice) && (srcNotTracked)) { + if ((kind == hipMemcpyHostToDevice) && (!srcTracked)) { int depSignalCnt = preCopyCommand(NULL, &depSignal, ihipCommandCopyH2D); if (HIP_STAGING_BUFFERS) { - tprintf(DB_COPY1, "D2H && dstNotTracked: staged copy H2D dst=%p src=%p sz=%zu\n", dst, src, sizeBytes); + tprintf(DB_COPY1, "H2D && !srcTracked: staged copy H2D dst=%p src=%p sz=%zu\n", dst, src, sizeBytes); if (HIP_PININPLACE) { device->_staging_buffer[0]->CopyHostToDevicePinInPlace(dst, src, sizeBytes, depSignalCnt ? &depSignal : NULL); @@ -2281,13 +2289,13 @@ void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMem this->wait(true); } else { // TODO - remove, slow path. - tprintf(DB_COPY1, "H2D && srcNotTracked: am_copy dst=%p src=%p sz=%zu\n", dst, src, sizeBytes); + tprintf(DB_COPY1, "H2D && !srcTracked: am_copy dst=%p src=%p sz=%zu\n", dst, src, sizeBytes); hc::am_copy(dst, src, sizeBytes); } - } else if ((kind == hipMemcpyDeviceToHost) && (dstNotTracked)) { + } else if ((kind == hipMemcpyDeviceToHost) && (!dstTracked)) { int depSignalCnt = preCopyCommand(NULL, &depSignal, ihipCommandCopyD2H); if (HIP_STAGING_BUFFERS) { - tprintf(DB_COPY1, "D2H && dstNotTracked: staged copy D2H dst=%p src=%p sz=%zu\n", dst, src, sizeBytes); + tprintf(DB_COPY1, "D2H && !dstTracked: staged copy D2H dst=%p src=%p sz=%zu\n", dst, src, sizeBytes); //printf ("staged-copy- read dep signals\n"); device->_staging_buffer[1]->CopyDeviceToHost(dst, src, sizeBytes, depSignalCnt ? 
&depSignal : NULL); @@ -2296,7 +2304,7 @@ void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMem } else { // TODO - remove, slow path. - tprintf(DB_COPY1, "D2H && dstNotTracked: am_copy dst=%p src=%p sz=%zu\n", dst, src, sizeBytes); + tprintf(DB_COPY1, "D2H && !dstTracked: am_copy dst=%p src=%p sz=%zu\n", dst, src, sizeBytes); hc::am_copy(dst, src, sizeBytes); } } else if (kind == hipMemcpyHostToHost) { @@ -2341,6 +2349,8 @@ void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMem } + + void ihipStream_t::copyAsync(void* dst, const void* src, size_t sizeBytes, hipMemcpyKind kind) { ihipDevice_t *device = this->getDevice(); @@ -2364,13 +2374,11 @@ void ihipStream_t::copyAsync(void* dst, const void* src, size_t sizeBytes, hipMe bool trueAsync = true; hc::accelerator acc; - hc::AmPointerInfo dstAm(NULL, NULL, 0, acc, 0, 0); - hc::AmPointerInfo srcAm(NULL, NULL, 0, acc, 0, 0); - bool dstTracked = (hc::am_memtracker_getinfo(&dstAm, dst) == AM_SUCCESS); - bool srcTracked = (hc::am_memtracker_getinfo(&srcAm, src) == AM_SUCCESS); + hc::AmPointerInfo dstPtrInfo(NULL, NULL, 0, acc, 0, 0); + hc::AmPointerInfo srcPtrInfo(NULL, NULL, 0, acc, 0, 0); + bool dstTracked = (hc::am_memtracker_getinfo(&dstPtrInfo, dst) == AM_SUCCESS); + bool srcTracked = (hc::am_memtracker_getinfo(&srcPtrInfo, src) == AM_SUCCESS); - bool dstInDeviceMem = (dstTracked && dstAm._isInDeviceMem); - bool srcInDeviceMem = (srcTracked && srcAm._isInDeviceMem); // "tracked" really indicates if the pointer's virtual address is available in the GPU address space. // If both pointers are not tracked, we need to fall back to a sync copy. 
@@ -2379,20 +2387,9 @@ void ihipStream_t::copyAsync(void* dst, const void* src, size_t sizeBytes, hipMe } if (kind == hipMemcpyDefault) { - if (!dstInDeviceMem && !srcInDeviceMem) { - kind = hipMemcpyHostToHost; - } else if (dstInDeviceMem && !srcInDeviceMem) { - kind = hipMemcpyHostToDevice; - } else if (!dstInDeviceMem && srcInDeviceMem) { - kind = hipMemcpyDeviceToHost; - } else if (dstInDeviceMem && srcInDeviceMem) { - kind = hipMemcpyDeviceToHost; - } - - // If we still couldn't determine direction, flag error here: - if (kind == hipMemcpyDefault) { - throw ihipException(hipErrorInvalidMemcpyDirection); - } + bool srcInDeviceMem = (srcTracked && srcPtrInfo._isInDeviceMem); // an untracked pointer is resolved as host memory + bool dstInDeviceMem = (dstTracked && dstPtrInfo._isInDeviceMem); + kind = resolveMemcpyDirection(srcInDeviceMem, dstInDeviceMem); }