Refactor the __device__ versions of memset and memcpy to be less awkward, i.e. to return the destination pointer rather than nullptr (one can only assume the latter was done for maximum confusion), and to actually unroll as they claim to. Change all of the {to, from}Symbol functions to use hipModuleGetGlobal instead of hc::accelerator::get_symbol_address, which is no longer valid with module-based dispatch.

Этот коммит содержится в:
Alex Voicu
2017-11-21 02:40:34 +00:00
родитель 1824fb7698
Коммит 9d088d2283
5 изменённых файлов: 206 добавлений и 113 удалений
+18 -6
Просмотреть файл
@@ -715,7 +715,10 @@ hipError_t hipMemcpyToSymbol(const void* symbolName, const void *src, size_t cou
hc::accelerator acc = ctx->getDevice()->_acc;
void *dst = acc.get_symbol_address((const char*) symbolName);
hipDeviceptr_t dst = nullptr;
size_t byte_cnt = 0u;
auto status = hipModuleGetGlobal(
&dst, &byte_cnt, 0, static_cast<const char*>(symbolName));
tprintf(DB_MEM, " symbol '%s' resolved to address:%p\n", symbolName, dst);
if(dst == nullptr)
@@ -750,7 +753,10 @@ hipError_t hipMemcpyFromSymbol(void* dst, const void* symbolName, size_t count,
hc::accelerator acc = ctx->getDevice()->_acc;
void *src = acc.get_symbol_address((const char*) symbolName);
hipDeviceptr_t src = nullptr;
size_t byte_cnt = 0u;
auto status = hipModuleGetGlobal(
&src, &byte_cnt, 0, static_cast<const char*>(symbolName));
tprintf(DB_MEM, " symbol '%s' resolved to address:%p\n", symbolName, dst);
if(dst == nullptr)
@@ -787,7 +793,10 @@ hipError_t hipMemcpyToSymbolAsync(const void* symbolName, const void *src, size_
hc::accelerator acc = ctx->getDevice()->_acc;
void *dst = acc.get_symbol_address((const char*) symbolName);
hipDeviceptr_t dst = nullptr;
size_t byte_cnt = 0u;
auto status = hipModuleGetGlobal(
&dst, &byte_cnt, 0, static_cast<const char*>(symbolName));
tprintf(DB_MEM, " symbol '%s' resolved to address:%p\n", symbolName, dst);
if(dst == nullptr)
@@ -825,7 +834,10 @@ hipError_t hipMemcpyFromSymbolAsync(void* dst, const void* symbolName, size_t co
hc::accelerator acc = ctx->getDevice()->_acc;
void *src = acc.get_symbol_address((const char*) symbolName);
hipDeviceptr_t src = nullptr;
size_t byte_cnt = 0u;
auto status = hipModuleGetGlobal(
&src, &byte_cnt, 0, static_cast<const char*>(symbolName));
tprintf(DB_MEM, " symbol '%s' resolved to address:%p\n", symbolName, src);
if(src == nullptr || dst == nullptr)
@@ -1171,9 +1183,9 @@ namespace
__global__
void hip_fill_n(RandomAccessIterator f, N n, T value)
{
const uint32_t grid_dim = hipGridDim_x;
const uint32_t grid_dim = gridDim.x * blockDim.x;
size_t idx = hipBlockIdx_x * block_dim + hipThreadIdx_x;
size_t idx = blockIdx.x * block_dim + threadIdx.x;
while (idx < n) {
new (&f[idx]) T{value};
idx += grid_dim;