[HIPIFY][cuDNN] Add cudnnGetFilter4dDescriptor support
+ Update cudnn_convolution_forward test accordingly
Этот коммит содержится в:
@@ -293,7 +293,7 @@
|
||||
|`cudnnScaleTensor` |`hipdnnScaleTensor` |
|
||||
|`cudnnCreateFilterDescriptor` |`hipdnnCreateFilterDescriptor` |
|
||||
|`cudnnSetFilter4dDescriptor` |`hipdnnSetFilter4dDescriptor` |
|
||||
|`cudnnGetFilter4dDescriptor` | |
|
||||
|`cudnnGetFilter4dDescriptor` |`hipdnnGetFilter4dDescriptor` |
|
||||
|`cudnnSetFilterNdDescriptor` |`hipdnnSetFilterNdDescriptor` |
|
||||
|`cudnnGetFilterNdDescriptor` |`hipdnnGetFilterNdDescriptor` |
|
||||
|`cudnnDestroyFilterDescriptor` |`hipdnnDestroyFilterDescriptor` |
|
||||
|
||||
@@ -76,7 +76,7 @@ const std::map<llvm::StringRef, hipCounter> CUDA_DNN_FUNCTION_MAP{
|
||||
// cuDNN Filter functions
|
||||
{"cudnnCreateFilterDescriptor", {"hipdnnCreateFilterDescriptor", "", CONV_LIB_FUNC, API_DNN}},
|
||||
{"cudnnSetFilter4dDescriptor", {"hipdnnSetFilter4dDescriptor", "", CONV_LIB_FUNC, API_DNN}},
|
||||
{"cudnnGetFilter4dDescriptor", {"hipdnnGetFilter4dDescriptor", "", CONV_LIB_FUNC, API_DNN, HIP_UNSUPPORTED}},
|
||||
{"cudnnGetFilter4dDescriptor", {"hipdnnGetFilter4dDescriptor", "", CONV_LIB_FUNC, API_DNN}},
|
||||
{"cudnnSetFilterNdDescriptor", {"hipdnnSetFilterNdDescriptor", "", CONV_LIB_FUNC, API_DNN}},
|
||||
{"cudnnGetFilterNdDescriptor", {"hipdnnGetFilterNdDescriptor", "", CONV_LIB_FUNC, API_DNN}},
|
||||
{"cudnnDestroyFilterDescriptor", {"hipdnnDestroyFilterDescriptor", "", CONV_LIB_FUNC, API_DNN}},
|
||||
|
||||
@@ -176,6 +176,19 @@ int main() {
|
||||
out_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT,
|
||||
out_n, out_c, out_h, out_w));
|
||||
|
||||
|
||||
cudnnDataType_t *dataType = nullptr;
|
||||
cudnnTensorFormat_t *tensorFormat = nullptr;
|
||||
int *p_filt_k = nullptr;
|
||||
int *p_filt_c = nullptr;
|
||||
int *p_filt_h = nullptr;
|
||||
int *p_filt_w = nullptr;
|
||||
|
||||
// CHECK: CUDNN_CALL(hipdnnGetFilter4dDescriptor(
|
||||
CUDNN_CALL(cudnnGetFilter4dDescriptor(
|
||||
filt_desc, dataType, tensorFormat,
|
||||
p_filt_k, p_filt_c, p_filt_h, p_filt_w));
|
||||
|
||||
float *out_data;
|
||||
// CHECK: CUDA_CALL(hipMalloc(
|
||||
CUDA_CALL(cudaMalloc(
|
||||
|
||||
Ссылка в новой задаче
Block a user