diff --git a/src/roctracer/roctracer.cpp b/src/roctracer/roctracer.cpp index fdc99ac8db..d77d600fdf 100644 --- a/src/roctracer/roctracer.cpp +++ b/src/roctracer/roctracer.cpp @@ -186,29 +186,6 @@ class HIPActivityCallbackTracker { static HIPActivityCallbackTracker hip_act_cb_tracker; -inline uint32_t HipApiCallbackEnableCheck(uint32_t op) { - const uint32_t mask = hip_act_cb_tracker.enable_check(op, API_CB_MASK); - const uint32_t ret = (mask & API_ACT_MASK); - return ret; -} - -inline uint32_t HipApiCallbackDisableCheck(uint32_t op) { - const uint32_t mask = hip_act_cb_tracker.disable_check(op, API_CB_MASK); - const uint32_t ret = (mask & API_ACT_MASK); - return ret; -} - -inline uint32_t HipApiActivityEnableCheck(uint32_t op) { - hip_act_cb_tracker.enable_check(op, API_ACT_MASK); - return 0; -} - -inline uint32_t HipApiActivityDisableCheck(uint32_t op) { - const uint32_t mask = hip_act_cb_tracker.disable_check(op, API_ACT_MASK); - const uint32_t ret = (mask & API_CB_MASK); - return ret; -} - void HIP_ApiCallback(uint32_t op_id, roctracer_record_t* record, void* callback_data, void* arg) { hip_api_data_t* data = static_cast(callback_data); MemoryPool* pool = static_cast(arg); @@ -522,17 +499,16 @@ static void roctracer_enable_callback_fun(roctracer_domain_t domain, uint32_t op if (!HipLoader::Instance().Enabled()) break; std::lock_guard lock(hip_activity_mutex); - hipError_t hip_err = - HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data); - if (hip_err != hipSuccess) - FATAL_LOGGING("HIP::RegisterApiCallback(" << op << ") error(" << hip_err << ")"); + if (hipError_t err = + HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data); + err != hipSuccess) + FATAL_LOGGING("HIP::RegisterApiCallback(" << op << ") error(" << err << ")"); - if (HipApiCallbackEnableCheck(op) == 0) { - hip_err = - HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr); - if (hip_err != hipSuccess) - FATAL_LOGGING("HIPAPI: HIP::RegisterActivityCallback(" << op << ") error(" << hip_err - << ")"); + if ((hip_act_cb_tracker.enable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) { + if (hipError_t err = + HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr); + err != hipSuccess) + FATAL_LOGGING("HIPAPI: HIP::RegisterActivityCallback(" << op << ") error(" << err << ")"); } break; } @@ -594,14 +570,12 @@ static void roctracer_disable_callback_fun(roctracer_domain_t domain, uint32_t o if (!HipLoader::Instance().Enabled()) break; std::lock_guard lock(hip_activity_mutex); - const hipError_t hip_err = HipLoader::Instance().RemoveApiCallback(op); - if (hip_err != hipSuccess) - FATAL_LOGGING("HIP::RemoveApiCallback(" << op << "), error(" << hip_err << ")"); + if (hipError_t err = HipLoader::Instance().RemoveApiCallback(op); err != hipSuccess) + FATAL_LOGGING("HIP::RemoveApiCallback(" << op << "), error(" << err << ")"); - if (HipApiCallbackDisableCheck(op) == 0) { - const hipError_t hip_err = HipLoader::Instance().RemoveActivityCallback(op); - if (hip_err != hipSuccess) - FATAL_LOGGING("HIPAPI: HIP::RemoveActivityCallback op(" << op << "), error(" << hip_err + if ((hip_act_cb_tracker.disable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) { + if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess) + FATAL_LOGGING("HIPAPI: HIP::RemoveActivityCallback op(" << op << "), error(" << err << ")"); } break; @@ -739,12 +713,11 @@ static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op if (!HipLoader::Instance().Enabled()) break; std::lock_guard lock(hip_activity_mutex); - if (HipApiActivityEnableCheck(op) == 0) { - const hipError_t hip_err = - HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool); - if (hip_err != hipSuccess) - FATAL_LOGGING("HIP::RegisterActivityCallback(" << op << " error(" << hip_err << ")"); - } + hip_act_cb_tracker.enable_check(op, API_ACT_MASK); + if (hipError_t err = + HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool); + err != hipSuccess) + FATAL_LOGGING("HIP::RegisterActivityCallback(" << op << " error(" << err << ")"); break; } case ACTIVITY_DOMAIN_ROCTX: @@ -835,16 +808,14 @@ static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t o if (!HipLoader::Instance().Enabled()) break; std::lock_guard lock(hip_activity_mutex); - if (HipApiActivityDisableCheck(op) == 0) { - const hipError_t hip_err = HipLoader::Instance().RemoveActivityCallback(op); - if (hip_err != hipSuccess) - FATAL_LOGGING("HIP::RemoveActivityCallback op(" << op << "), error(" << hip_err << ")"); + if ((hip_act_cb_tracker.disable_check(op, API_ACT_MASK) & API_CB_MASK) == 0) { + if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess) + FATAL_LOGGING("HIP::RemoveActivityCallback op(" << op << "), error(" << err << ")"); } else { - const hipError_t hip_err = - HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr); - if (hip_err != hipSuccess) - FATAL_LOGGING("HIPACT: HIP::RegisterActivityCallback(" << op << ") error(" << hip_err - << ")"); + if (hipError_t err = + HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr); + err != hipSuccess) + FATAL_LOGGING("HIPACT: HIP::RegisterActivityCallback(" << op << ") error(" << err << ")"); } break; }