Refactor copy - place common code in resolveMemoryKind.

[ROCm/hip commit: 194d02ac5a]
Этот коммит содержится в:
Ben Sander
2016-03-19 22:56:10 -05:00
родитель 4c77ecef9a
Коммит b520a34579
+40 -43
Просмотреть файл
@@ -2228,6 +2228,25 @@ hipError_t hipMemcpyToSymbol(const char* symbolName, const void *src, size_t cou
return ihipLogStatus(hipSuccess);
}
// Resolve hipMemcpyDefault to a known type.
hipMemcpyKind resolveMemcpyDirection(bool srcInDeviceMem, bool dstInDeviceMem)
{
hipMemcpyKind kind = hipMemcpyDefault;
if (!srcInDeviceMem && !dstInDeviceMem) {
kind = hipMemcpyHostToHost;
} else if (!srcInDeviceMem && dstInDeviceMem) {
kind = hipMemcpyHostToDevice;
} else if (srcInDeviceMem && !dstInDeviceMem) {
kind = hipMemcpyDeviceToHost;
} else if (srcInDeviceMem && dstInDeviceMem) {
kind = hipMemcpyDeviceToDevice;
}
assert (kind != hipMemcpyDefault);
return kind;
}
void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMemcpyKind kind)
@@ -2242,34 +2261,23 @@ void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMem
hc::AmPointerInfo dstPtrInfo(NULL, NULL, 0, acc, 0, 0);
hc::AmPointerInfo srcPtrInfo(NULL, NULL, 0, acc, 0, 0);
bool dstNotTracked = (hc::am_memtracker_getinfo(&dstPtrInfo, dst) != AM_SUCCESS);
bool srcNotTracked = (hc::am_memtracker_getinfo(&srcPtrInfo, src) != AM_SUCCESS);
bool dstTracked = (hc::am_memtracker_getinfo(&dstPtrInfo, dst) == AM_SUCCESS);
bool srcTracked = (hc::am_memtracker_getinfo(&srcPtrInfo, src) == AM_SUCCESS);
// Resolve default to a specific Kind so we know which algorithm to use:
if (kind == hipMemcpyDefault) {
bool dstIsHost = (dstNotTracked || !dstPtrInfo._isInDeviceMem);
bool srcIsHost = (srcNotTracked || !srcPtrInfo._isInDeviceMem);
if (srcIsHost && !dstIsHost) {
kind = hipMemcpyHostToDevice;
} else if (!srcIsHost && dstIsHost) {
kind = hipMemcpyDeviceToHost;
} else if (srcIsHost && dstIsHost) {
kind = hipMemcpyHostToHost;
} else if (!srcIsHost && !dstIsHost) {
kind = hipMemcpyDeviceToDevice;
} else {
throw ihipException(hipErrorInvalidMemcpyDirection);
}
}
bool srcInDeviceMem = (srcTracked && srcPtrInfo._isInDeviceMem);
bool dstInDeviceMem = (dstTracked && dstPtrInfo._isInDeviceMem);
kind = resolveMemcpyDirection(srcInDeviceMem, dstInDeviceMem);
};
hsa_signal_t depSignal;
if ((kind == hipMemcpyHostToDevice) && (srcNotTracked)) {
if ((kind == hipMemcpyHostToDevice) && (!srcTracked)) {
int depSignalCnt = preCopyCommand(NULL, &depSignal, ihipCommandCopyH2D);
if (HIP_STAGING_BUFFERS) {
tprintf(DB_COPY1, "D2H && dstNotTracked: staged copy H2D dst=%p src=%p sz=%zu\n", dst, src, sizeBytes);
tprintf(DB_COPY1, "D2H && !dstTracked: staged copy H2D dst=%p src=%p sz=%zu\n", dst, src, sizeBytes);
if (HIP_PININPLACE) {
device->_staging_buffer[0]->CopyHostToDevicePinInPlace(dst, src, sizeBytes, depSignalCnt ? &depSignal : NULL);
@@ -2281,13 +2289,13 @@ void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMem
this->wait(true);
} else {
// TODO - remove, slow path.
tprintf(DB_COPY1, "H2D && srcNotTracked: am_copy dst=%p src=%p sz=%zu\n", dst, src, sizeBytes);
tprintf(DB_COPY1, "H2D && ! srcTracked: am_copy dst=%p src=%p sz=%zu\n", dst, src, sizeBytes);
hc::am_copy(dst, src, sizeBytes);
}
} else if ((kind == hipMemcpyDeviceToHost) && (dstNotTracked)) {
} else if ((kind == hipMemcpyDeviceToHost) && (!dstTracked)) {
int depSignalCnt = preCopyCommand(NULL, &depSignal, ihipCommandCopyD2H);
if (HIP_STAGING_BUFFERS) {
tprintf(DB_COPY1, "D2H && dstNotTracked: staged copy D2H dst=%p src=%p sz=%zu\n", dst, src, sizeBytes);
tprintf(DB_COPY1, "D2H && !dstTracked: staged copy D2H dst=%p src=%p sz=%zu\n", dst, src, sizeBytes);
//printf ("staged-copy- read dep signals\n");
device->_staging_buffer[1]->CopyDeviceToHost(dst, src, sizeBytes, depSignalCnt ? &depSignal : NULL);
@@ -2296,7 +2304,7 @@ void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMem
} else {
// TODO - remove, slow path.
tprintf(DB_COPY1, "D2H && dstNotTracked: am_copy dst=%p src=%p sz=%zu\n", dst, src, sizeBytes);
tprintf(DB_COPY1, "D2H && !dstTracked: am_copy dst=%p src=%p sz=%zu\n", dst, src, sizeBytes);
hc::am_copy(dst, src, sizeBytes);
}
} else if (kind == hipMemcpyHostToHost) {
@@ -2341,6 +2349,8 @@ void ihipStream_t::copySync(void* dst, const void* src, size_t sizeBytes, hipMem
}
void ihipStream_t::copyAsync(void* dst, const void* src, size_t sizeBytes, hipMemcpyKind kind)
{
ihipDevice_t *device = this->getDevice();
@@ -2364,13 +2374,11 @@ void ihipStream_t::copyAsync(void* dst, const void* src, size_t sizeBytes, hipMe
bool trueAsync = true;
hc::accelerator acc;
hc::AmPointerInfo dstAm(NULL, NULL, 0, acc, 0, 0);
hc::AmPointerInfo srcAm(NULL, NULL, 0, acc, 0, 0);
bool dstTracked = (hc::am_memtracker_getinfo(&dstAm, dst) == AM_SUCCESS);
bool srcTracked = (hc::am_memtracker_getinfo(&srcAm, src) == AM_SUCCESS);
hc::AmPointerInfo dstPtrInfo(NULL, NULL, 0, acc, 0, 0);
hc::AmPointerInfo srcPtrInfo(NULL, NULL, 0, acc, 0, 0);
bool dstTracked = (hc::am_memtracker_getinfo(&dstPtrInfo, dst) == AM_SUCCESS);
bool srcTracked = (hc::am_memtracker_getinfo(&srcPtrInfo, src) == AM_SUCCESS);
bool dstInDeviceMem = (dstTracked && dstAm._isInDeviceMem);
bool srcInDeviceMem = (srcTracked && srcAm._isInDeviceMem);
// "tracked" really indicates if the pointer's virtual address is available in the GPU address space.
// If both pointers are not tracked, we need to fall back to a sync copy.
@@ -2379,20 +2387,9 @@ void ihipStream_t::copyAsync(void* dst, const void* src, size_t sizeBytes, hipMe
}
if (kind == hipMemcpyDefault) {
if (!dstInDeviceMem && !srcInDeviceMem) {
kind = hipMemcpyHostToHost;
} else if (dstInDeviceMem && !srcInDeviceMem) {
kind = hipMemcpyHostToDevice;
} else if (!dstInDeviceMem && srcInDeviceMem) {
kind = hipMemcpyDeviceToHost;
} else if (dstInDeviceMem && srcInDeviceMem) {
kind = hipMemcpyDeviceToHost;
}
// If we still couldn't determine direction, flag error here:
if (kind == hipMemcpyDefault) {
throw ihipException(hipErrorInvalidMemcpyDirection);
}
bool srcInDeviceMem = (srcTracked && srcPtrInfo._isInDeviceMem);
bool dstInDeviceMem = (dstTracked && dstPtrInfo._isInDeviceMem);
kind = resolveMemcpyDirection(srcInDeviceMem, dstInDeviceMem);
}