diff --git a/src/core/loader.h b/src/core/loader.h index 1d85a31787..43b07be905 100644 --- a/src/core/loader.h +++ b/src/core/loader.h @@ -48,7 +48,7 @@ class BaseLoader : public T { BaseLoader() { const int flags = (to_load_ == true) ? RTLD_LAZY : RTLD_LAZY|RTLD_NOLOAD; handle_ = dlopen(lib_name_, flags); - if (handle_ == NULL) { + if ((to_check_ == true) && (handle_ == NULL)) { fprintf(stderr, "roctracer: Loading '%s' failed, %s\n", lib_name_, dlerror()); abort(); } @@ -62,6 +62,7 @@ class BaseLoader : public T { } static bool to_load_; + static bool to_check_; static mutex_t mutex_; static const char* lib_name_; @@ -174,8 +175,11 @@ typedef BaseLoader RocTxLoader; template typename roctracer::BaseLoader::mutex_t roctracer::BaseLoader::mutex_; \ template std::atomic*> roctracer::BaseLoader::instance_{}; \ template bool roctracer::BaseLoader::to_load_ = false; \ + template bool roctracer::BaseLoader::to_check_ = true; \ template<> const char* roctracer::HipLoader::lib_name_ = "libhip_hcc.so"; \ + template<> bool roctracer::HipLoader::to_check_ = false; \ template<> const char* roctracer::HccLoader::lib_name_ = "libmcwamp.so"; \ + template<> bool roctracer::HccLoader::to_check_ = false; \ template<> const char* roctracer::KfdLoader::lib_name_ = "libkfdwrapper64.so"; \ template<> const char* roctracer::RocTxLoader::lib_name_ = "libroctx64.so"; \ template<> bool roctracer::RocTxLoader::to_load_ = true; diff --git a/src/core/roctracer.cpp b/src/core/roctracer.cpp index 07a998db11..eea72bf12d 100644 --- a/src/core/roctracer.cpp +++ b/src/core/roctracer.cpp @@ -640,6 +640,8 @@ static roctracer_status_t roctracer_enable_callback_fun( } case ACTIVITY_DOMAIN_HCC_OPS: break; case ACTIVITY_DOMAIN_HIP_API: { + if (roctracer::HipLoader::Instance().Enabled() == false) break; + hipError_t hip_err = roctracer::HipLoader::Instance().RegisterApiCallback(op, (void*)callback, user_data); if (hip_err != hipSuccess) HIP_EXC_RAISING(ROCTRACER_STATUS_HIP_API_ERR, "hipRegisterApiCallback(" << op << ") error(" << hip_err << ")"); break; @@ -716,6 +718,8 @@ static roctracer_status_t roctracer_disable_callback_fun( case ACTIVITY_DOMAIN_HSA_API: break; case ACTIVITY_DOMAIN_HCC_OPS: break; case ACTIVITY_DOMAIN_HIP_API: { + if (roctracer::HipLoader::Instance().Enabled() == false) break; + hipError_t hip_err = roctracer::HipLoader::Instance().RemoveApiCallback(op); if (hip_err != hipSuccess) HIP_EXC_RAISING(ROCTRACER_STATUS_HIP_API_ERR, "hipRemoveApiCallback error(" << hip_err << ")"); break; @@ -821,7 +825,10 @@ static roctracer_status_t roctracer_enable_activity_fun( case ACTIVITY_DOMAIN_HSA_API: break; case ACTIVITY_DOMAIN_KFD_API: break; case ACTIVITY_DOMAIN_HCC_OPS: { - if (roctracer::HccLoader::GetRef() == NULL) { + const bool init_phase = (roctracer::HccLoader::GetRef() == NULL); + if (roctracer::HccLoader::Instance().Enabled() == false) break; + + if (init_phase == true) { if (getenv("ROCP_HCC_CORRID_WAIT") != NULL) { roctracer::correlation_id_wait = true; fprintf(stdout, "roctracer: HCC correlation ID wait enabled\n"); fflush(stdout); @@ -839,6 +846,8 @@ static roctracer_status_t roctracer_enable_activity_fun( break; } case ACTIVITY_DOMAIN_HIP_API: { + if (roctracer::HipLoader::Instance().Enabled() == false) break; + const hipError_t hip_err = roctracer::HipLoader::Instance().RegisterActivityCallback(op, (void*)roctracer::HIP_SyncActivityCallback, (void*)pool); if (hip_err != hipSuccess) HIP_EXC_RAISING(ROCTRACER_STATUS_HIP_API_ERR, "hipRegisterActivityCallback error(" << hip_err << ")"); break; @@ -904,11 +913,15 @@ static roctracer_status_t roctracer_disable_activity_fun( case ACTIVITY_DOMAIN_HSA_API: break; case ACTIVITY_DOMAIN_KFD_API: break; case ACTIVITY_DOMAIN_HCC_OPS: { + if (roctracer::HccLoader::Instance().Enabled() == false) break; + const bool succ = roctracer::HccLoader::Instance().EnableActivityCallback(op, false); if (succ == false) HCC_EXC_RAISING(ROCTRACER_STATUS_HCC_OPS_ERR, "HCC::EnableActivityCallback(NULL) error domain(" << domain << ") op(" << op << ")"); break; } case ACTIVITY_DOMAIN_HIP_API: { + if (roctracer::HipLoader::Instance().Enabled() == false) break; + const hipError_t hip_err = roctracer::HipLoader::Instance().RemoveActivityCallback(op); if (hip_err != hipSuccess) HIP_EXC_RAISING(ROCTRACER_STATUS_HIP_API_ERR, "hipRemoveActivityCallback error(" << hip_err << ")"); break;