[HIPIFY][cuDNN] Add cudnnGetFilter4dDescriptor support

+ Update cudnn_convolution_forward test accordingly
Этот коммит содержится в:
Evgeny Mankov
2019-05-16 16:36:23 +03:00
родитель de7ec55bea
Коммит aed2affda2
3 изменённых файлов: 15 добавлений и 2 удалений
+1 -1
Просмотреть файл
@@ -293,7 +293,7 @@
|`cudnnScaleTensor` |`hipdnnScaleTensor` |
|`cudnnCreateFilterDescriptor` |`hipdnnCreateFilterDescriptor` |
|`cudnnSetFilter4dDescriptor` |`hipdnnSetFilter4dDescriptor` |
|`cudnnGetFilter4dDescriptor` | |
|`cudnnGetFilter4dDescriptor` |`hipdnnGetFilter4dDescriptor` |
|`cudnnSetFilterNdDescriptor` |`hipdnnSetFilterNdDescriptor` |
|`cudnnGetFilterNdDescriptor` |`hipdnnGetFilterNdDescriptor` |
|`cudnnDestroyFilterDescriptor` |`hipdnnDestroyFilterDescriptor` |
+1 -1
Просмотреть файл
@@ -76,7 +76,7 @@ const std::map<llvm::StringRef, hipCounter> CUDA_DNN_FUNCTION_MAP{
// cuDNN Filter functions
{"cudnnCreateFilterDescriptor", {"hipdnnCreateFilterDescriptor", "", CONV_LIB_FUNC, API_DNN}},
{"cudnnSetFilter4dDescriptor", {"hipdnnSetFilter4dDescriptor", "", CONV_LIB_FUNC, API_DNN}},
{"cudnnGetFilter4dDescriptor", {"hipdnnGetFilter4dDescriptor", "", CONV_LIB_FUNC, API_DNN, HIP_UNSUPPORTED}},
{"cudnnGetFilter4dDescriptor", {"hipdnnGetFilter4dDescriptor", "", CONV_LIB_FUNC, API_DNN}},
{"cudnnSetFilterNdDescriptor", {"hipdnnSetFilterNdDescriptor", "", CONV_LIB_FUNC, API_DNN}},
{"cudnnGetFilterNdDescriptor", {"hipdnnGetFilterNdDescriptor", "", CONV_LIB_FUNC, API_DNN}},
{"cudnnDestroyFilterDescriptor", {"hipdnnDestroyFilterDescriptor", "", CONV_LIB_FUNC, API_DNN}},
+13
Просмотреть файл
@@ -176,6 +176,19 @@ int main() {
out_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
out_n, out_c, out_h, out_w));
cudnnDataType_t *dataType = nullptr;
cudnnTensorFormat_t *tensorFormat = nullptr;
int *p_filt_k = nullptr;
int *p_filt_c = nullptr;
int *p_filt_h = nullptr;
int *p_filt_w = nullptr;
// CHECK: CUDNN_CALL(hipdnnGetFilter4dDescriptor(
CUDNN_CALL(cudnnGetFilter4dDescriptor(
filt_desc, dataType, tensorFormat,
p_filt_k, p_filt_c, p_filt_h, p_filt_w));
float *out_data;
// CHECK: CUDA_CALL(hipMalloc(
CUDA_CALL(cudaMalloc(