SWDEV-351980 - Remove HipApi{Callback|Activity}{Enable|Disable}Check

The code is easier to read if calling HIPActivityCallbackTracker
enable/disable_check directly. Both enable/disable_check return the
new mask, and the check whether a callback is already installed is
clearer.

Change-Id: Ic90d34489b5b4d9929dc08b4d9e93cc974b136b1
This commit is contained in:
Laurent Morichetti
2022-08-16 20:03:10 -07:00
rodzic 88c6e0a700
commit f0e082feb1
+26 -55
Wyświetl plik
@@ -186,29 +186,6 @@ class HIPActivityCallbackTracker {
static HIPActivityCallbackTracker hip_act_cb_tracker;
inline uint32_t HipApiCallbackEnableCheck(uint32_t op) {
const uint32_t mask = hip_act_cb_tracker.enable_check(op, API_CB_MASK);
const uint32_t ret = (mask & API_ACT_MASK);
return ret;
}
inline uint32_t HipApiCallbackDisableCheck(uint32_t op) {
const uint32_t mask = hip_act_cb_tracker.disable_check(op, API_CB_MASK);
const uint32_t ret = (mask & API_ACT_MASK);
return ret;
}
inline uint32_t HipApiActivityEnableCheck(uint32_t op) {
hip_act_cb_tracker.enable_check(op, API_ACT_MASK);
return 0;
}
inline uint32_t HipApiActivityDisableCheck(uint32_t op) {
const uint32_t mask = hip_act_cb_tracker.disable_check(op, API_ACT_MASK);
const uint32_t ret = (mask & API_CB_MASK);
return ret;
}
void HIP_ApiCallback(uint32_t op_id, roctracer_record_t* record, void* callback_data, void* arg) {
hip_api_data_t* data = static_cast<hip_api_data_t*>(callback_data);
MemoryPool* pool = static_cast<MemoryPool*>(arg);
@@ -522,17 +499,16 @@ static void roctracer_enable_callback_fun(roctracer_domain_t domain, uint32_t op
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);
hipError_t hip_err =
HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIP::RegisterApiCallback(" << op << ") error(" << hip_err << ")");
if (hipError_t err =
HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data);
err != hipSuccess)
FATAL_LOGGING("HIP::RegisterApiCallback(" << op << ") error(" << err << ")");
if (HipApiCallbackEnableCheck(op) == 0) {
hip_err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIPAPI: HIP::RegisterActivityCallback(" << op << ") error(" << hip_err
<< ")");
if ((hip_act_cb_tracker.enable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) {
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
err != hipSuccess)
FATAL_LOGGING("HIPAPI: HIP::RegisterActivityCallback(" << op << ") error(" << err << ")");
}
break;
}
@@ -594,14 +570,12 @@ static void roctracer_disable_callback_fun(roctracer_domain_t domain, uint32_t o
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);
const hipError_t hip_err = HipLoader::Instance().RemoveApiCallback(op);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIP::RemoveApiCallback(" << op << "), error(" << hip_err << ")");
if (hipError_t err = HipLoader::Instance().RemoveApiCallback(op); err != hipSuccess)
FATAL_LOGGING("HIP::RemoveApiCallback(" << op << "), error(" << err << ")");
if (HipApiCallbackDisableCheck(op) == 0) {
const hipError_t hip_err = HipLoader::Instance().RemoveActivityCallback(op);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIPAPI: HIP::RemoveActivityCallback op(" << op << "), error(" << hip_err
if ((hip_act_cb_tracker.disable_check(op, API_CB_MASK) & API_ACT_MASK) == 0) {
if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess)
FATAL_LOGGING("HIPAPI: HIP::RemoveActivityCallback op(" << op << "), error(" << err
<< ")");
}
break;
@@ -739,12 +713,11 @@ static void roctracer_enable_activity_fun(roctracer_domain_t domain, uint32_t op
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);
if (HipApiActivityEnableCheck(op) == 0) {
const hipError_t hip_err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIP::RegisterActivityCallback(" << op << " error(" << hip_err << ")");
}
hip_act_cb_tracker.enable_check(op, API_ACT_MASK);
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, pool);
err != hipSuccess)
FATAL_LOGGING("HIP::RegisterActivityCallback(" << op << " error(" << err << ")");
break;
}
case ACTIVITY_DOMAIN_ROCTX:
@@ -835,16 +808,14 @@ static void roctracer_disable_activity_fun(roctracer_domain_t domain, uint32_t o
if (!HipLoader::Instance().Enabled()) break;
std::lock_guard lock(hip_activity_mutex);
if (HipApiActivityDisableCheck(op) == 0) {
const hipError_t hip_err = HipLoader::Instance().RemoveActivityCallback(op);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIP::RemoveActivityCallback op(" << op << "), error(" << hip_err << ")");
if ((hip_act_cb_tracker.disable_check(op, API_ACT_MASK) & API_CB_MASK) == 0) {
if (hipError_t err = HipLoader::Instance().RemoveActivityCallback(op); err != hipSuccess)
FATAL_LOGGING("HIP::RemoveActivityCallback op(" << op << "), error(" << err << ")");
} else {
const hipError_t hip_err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
if (hip_err != hipSuccess)
FATAL_LOGGING("HIPACT: HIP::RegisterActivityCallback(" << op << ") error(" << hip_err
<< ")");
if (hipError_t err =
HipLoader::Instance().RegisterActivityCallback(op, (void*)HIP_ApiCallback, nullptr);
err != hipSuccess)
FATAL_LOGGING("HIPACT: HIP::RegisterActivityCallback(" << op << ") error(" << err << ")");
}
break;
}