Fix memcpy2d kernel launch grid dimensions and add a width/height bounds check in the kernel

This commit is contained in:
Rahul Garg
2018-05-24 17:00:12 +05:30
parent dc179e0c33
commit 981e56a68f
+14 -12
View File
@@ -1496,17 +1496,19 @@ __global__ void hip_copy2d_n(T* dst, const T* src, size_t width, size_t height,
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
size_t idy = blockIdx.y * blockDim.y + threadIdx.y;
size_t floorWidth = (width/sizeof(T));
if((idx < floorWidth)){
T *dstPtr = (T *)((uint8_t*) dst + idy * destPitch);
T *srcPtr = (T *)((uint8_t*) src + idy * srcPitch);
dstPtr[idx] = srcPtr[idx];
} else {
size_t bytesToCopy = width - (floorWidth * sizeof(T));
uint8_t *dstPtr = (uint8_t *) ((uint8_t*) dst + idy * destPitch);
uint8_t *srcPtr = (uint8_t *) ((uint8_t*) src + idy * srcPitch);
for(int i =0 ; i < bytesToCopy ; i++) {
dstPtr[idx+i]= srcPtr[idx+i];
}
if((idx < width) && (idy < height)) {
if((idx < floorWidth)){
T *dstPtr = (T *)((uint8_t*) dst + idy * destPitch);
T *srcPtr = (T *)((uint8_t*) src + idy * srcPitch);
dstPtr[idx] = srcPtr[idx];
} else {
size_t bytesToCopy = width - (floorWidth * sizeof(T));
uint8_t *dstPtr = (uint8_t *) ((uint8_t*) dst + idy * destPitch);
uint8_t *srcPtr = (uint8_t *) ((uint8_t*) src + idy * srcPitch);
for(int i =0 ; i < bytesToCopy ; i++) {
dstPtr[idx+i]= srcPtr[idx+i];
}
}
}
}
} // namespace
@@ -1524,7 +1526,7 @@ void ihipMemsetKernel(hipStream_t stream, T* ptr, T val, size_t sizeBytes) {
template <typename T>
void ihipMemcpy2dKernel(hipStream_t stream, T* dst, const T* src, size_t width, size_t height, size_t destPitch, size_t srcPitch) {
size_t threadsPerBlock = 16;
uint32_t grid_dim_x = clamp_integer<size_t>( ((width/sizeof(T))+(threadsPerBlock-1)) / threadsPerBlock, 1, UINT32_MAX);
uint32_t grid_dim_x = clamp_integer<size_t>( (width+(threadsPerBlock*sizeof(T)-1)) / (threadsPerBlock*sizeof(T)), 1, UINT32_MAX);
uint32_t grid_dim_y = clamp_integer<size_t>( (height+(threadsPerBlock-1)) / threadsPerBlock, 1, UINT32_MAX);
hipLaunchKernelGGL(hip_copy2d_n, dim3(grid_dim_x,grid_dim_y), dim3(threadsPerBlock,threadsPerBlock), 0u, stream, dst, src,
width, height, destPitch, srcPitch);