diff --git a/inc/roctracer.h b/inc/roctracer.h index fcd1aaf79c..b91037348e 100644 --- a/inc/roctracer.h +++ b/inc/roctracer.h @@ -87,7 +87,8 @@ typedef enum { ROCTRACER_STATUS_UNINIT = 2, ROCTRACER_STATUS_BREAK = 3, ROCTRACER_STATUS_BAD_DOMAIN = 4, - ROCTRACER_STATUS_HIP_API_ERR = 5, + ROCTRACER_STATUS_BAD_PARAMETER = 5, + ROCTRACER_STATUS_HIP_API_ERR = 6, } roctracer_status_t; //////////////////////////////////////////////////////////////////////////////// diff --git a/src/core/roctracer.cpp b/src/core/roctracer.cpp index a80e323934..704e1aba8e 100644 --- a/src/core/roctracer.cpp +++ b/src/core/roctracer.cpp @@ -396,7 +396,7 @@ PUBLIC_API roctracer_status_t roctracer_enable_api_callback( API_METHOD_PREFIX switch (domain) { case ROCTRACER_DOMAIN_ANY: - cid = 0; + if (cid != HIP_API_ID_ANY) HIP_EXC_RAISING(ROCTRACER_STATUS_BAD_PARAMETER, "DOMAIN_ANY and cid != HIP_API_ID_ANY"); case ROCTRACER_DOMAIN_HIP_API: { hipError_t hip_err = hipRegisterApiCallback(cid, callback, user_data); if (hip_err != hipSuccess) HIP_EXC_RAISING(ROCTRACER_STATUS_HIP_API_ERR, "hipRegisterApiCallback error(" << hip_err << ")"); @@ -416,7 +416,7 @@ PUBLIC_API roctracer_status_t roctracer_disable_api_callback( API_METHOD_PREFIX switch (domain) { case ROCTRACER_DOMAIN_ANY: - cid = 0; + if (cid != HIP_API_ID_ANY) HIP_EXC_RAISING(ROCTRACER_STATUS_BAD_PARAMETER, "DOMAIN_ANY and cid != HIP_API_ID_ANY"); case ROCTRACER_DOMAIN_HIP_API: { hipError_t hip_err = hipRemoveApiCallback(cid); if (hip_err != hipSuccess) HIP_EXC_RAISING(ROCTRACER_STATUS_HIP_API_ERR, "hipRemoveApiCallback error(" << hip_err << ")"); @@ -472,7 +472,7 @@ PUBLIC_API roctracer_status_t roctracer_enable_api_activity( if (pool == NULL) pool = roctracer_default_pool(); switch (domain) { case ROCTRACER_DOMAIN_ANY: - activity_kind = 0; + if (activity_kind != HIP_API_ID_ANY) HIP_EXC_RAISING(ROCTRACER_STATUS_BAD_PARAMETER, "DOMAIN_ANY and activity_kind != HIP_API_ID_ANY"); case ROCTRACER_DOMAIN_HIP_API: { const hipError_t hip_err = hipRegisterActivityCallback(activity_kind, roctracer::ActivityCallback, roctracer::ActivityAsyncCallback, pool); if (hip_err != hipSuccess) HIP_EXC_RAISING(ROCTRACER_STATUS_HIP_API_ERR, "hipRegisterActivityCallback error(" << hip_err << ")"); @@ -492,7 +492,7 @@ PUBLIC_API roctracer_status_t roctracer_disable_api_activity( API_METHOD_PREFIX switch (domain) { case ROCTRACER_DOMAIN_ANY: - activity_kind = 0; + if (activity_kind != HIP_API_ID_ANY) HIP_EXC_RAISING(ROCTRACER_STATUS_BAD_PARAMETER, "DOMAIN_ANY and activity_kind != HIP_API_ID_ANY"); case ROCTRACER_DOMAIN_HIP_API: { const hipError_t hip_err = hipRemoveActivityCallback(activity_kind); if (hip_err != hipSuccess) HIP_EXC_RAISING(ROCTRACER_STATUS_HIP_API_ERR, "hipRemoveActivityCallback error(" << hip_err << ")"); diff --git a/test/MatrixTranspose/MatrixTranspose.cpp b/test/MatrixTranspose/MatrixTranspose.cpp index 9819efe4f3..a2b14676a0 100644 --- a/test/MatrixTranspose/MatrixTranspose.cpp +++ b/test/MatrixTranspose/MatrixTranspose.cpp @@ -151,7 +151,7 @@ int main() { } while (0) // HIP API callback function -extern "C" void hip_api_callback( +void hip_api_callback( uint32_t domain, uint32_t cid, const void* callback_data, @@ -235,18 +235,18 @@ void init_tracing() { // Check tracer domains consitency ROCTRACER_CALL(roctracer_validate_domains()); // Enable HIP API callbacks - ROCTRACER_CALL(roctracer_enable_api_callback(ROCTRACER_DOMAIN_ANY, 0, hip_api_callback, NULL)); + ROCTRACER_CALL(roctracer_enable_api_callback(ROCTRACER_DOMAIN_ANY, HIP_API_ID_ANY, hip_api_callback, NULL)); // Enable HIP activity tracing roctracer_properties_t properties{}; properties.buffer_size = 12; properties.buffer_callback_fun = activity_callback; ROCTRACER_CALL(roctracer_open_pool(&properties)); - ROCTRACER_CALL(roctracer_enable_api_activity(ROCTRACER_DOMAIN_ANY, 0)); + ROCTRACER_CALL(roctracer_enable_api_activity(ROCTRACER_DOMAIN_ANY, HIP_API_ID_ANY)); } void finish_tracing() { - ROCTRACER_CALL(roctracer_disable_api_callback(ROCTRACER_DOMAIN_ANY, 0)); - ROCTRACER_CALL(roctracer_disable_api_activity(ROCTRACER_DOMAIN_ANY, 0)); + ROCTRACER_CALL(roctracer_disable_api_callback(ROCTRACER_DOMAIN_ANY, HIP_API_ID_ANY)); + ROCTRACER_CALL(roctracer_disable_api_activity(ROCTRACER_DOMAIN_ANY, HIP_API_ID_ANY)); ROCTRACER_CALL(roctracer_close_pool()); } #else