diff --git a/hipamd/src/hiprtc/hiprtc.cpp b/hipamd/src/hiprtc/hiprtc.cpp index 004836a712..78e3490e52 100644 --- a/hipamd/src/hiprtc/hiprtc.cpp +++ b/hipamd/src/hiprtc/hiprtc.cpp @@ -314,6 +314,11 @@ hiprtcResult hiprtcLinkAddFile(hiprtcLinkState hip_link_state, hiprtcJITInputTyp hiprtc::RTCLinkProgram* rtc_link_prog_ptr = reinterpret_cast(hip_link_state); + + if (!hiprtc::RTCLinkProgram::isLinkerValid(rtc_link_prog_ptr)) { + HIPRTC_RETURN(HIPRTC_ERROR_INVALID_INPUT); + } + if (!rtc_link_prog_ptr->AddLinkerFile(std::string(file_path), input_type)) { HIPRTC_RETURN(HIPRTC_ERROR_PROGRAM_CREATION_FAILURE); } @@ -344,6 +349,11 @@ hiprtcResult hiprtcLinkAddData(hiprtcLinkState hip_link_state, hiprtcJITInputTyp hiprtc::RTCLinkProgram* rtc_link_prog_ptr = reinterpret_cast(hip_link_state); + + if (!hiprtc::RTCLinkProgram::isLinkerValid(rtc_link_prog_ptr)) { + HIPRTC_RETURN(HIPRTC_ERROR_INVALID_INPUT); + } + if (!rtc_link_prog_ptr->AddLinkerData(image, image_size, input_name, input_type)) { HIPRTC_RETURN(HIPRTC_ERROR_PROGRAM_CREATION_FAILURE); } @@ -360,6 +370,11 @@ hiprtcResult hiprtcLinkComplete(hiprtcLinkState hip_link_state, void** bin_out, hiprtc::RTCLinkProgram* rtc_link_prog_ptr = reinterpret_cast(hip_link_state); + + if (!hiprtc::RTCLinkProgram::isLinkerValid(rtc_link_prog_ptr)) { + HIPRTC_RETURN(HIPRTC_ERROR_INVALID_INPUT); + } + if (!rtc_link_prog_ptr->LinkComplete(bin_out, size_out)) { HIPRTC_RETURN(HIPRTC_ERROR_LINKING); } @@ -372,7 +387,11 @@ hiprtcResult hiprtcLinkDestroy(hiprtcLinkState hip_link_state) { hiprtc::RTCLinkProgram* rtc_link_prog_ptr = reinterpret_cast(hip_link_state); - delete rtc_link_prog_ptr; + if (!hiprtc::RTCLinkProgram::isLinkerValid(rtc_link_prog_ptr)) { + HIPRTC_RETURN(HIPRTC_ERROR_INVALID_INPUT); + } + + delete rtc_link_prog_ptr; HIPRTC_RETURN(HIPRTC_SUCCESS); } diff --git a/hipamd/src/hiprtc/hiprtcInternal.cpp b/hipamd/src/hiprtc/hiprtcInternal.cpp index ac02aa1070..eaa9e4b80f 100644 --- a/hipamd/src/hiprtc/hiprtcInternal.cpp +++ b/hipamd/src/hiprtc/hiprtcInternal.cpp @@ -33,6 +33,7 @@ THE SOFTWARE. namespace hiprtc { using namespace helpers; +std::unordered_setRTCLinkProgram::linker_set_; std::vector getLinkOptions(const LinkArguments& args) { std::vector res; @@ -393,6 +394,16 @@ RTCLinkProgram::RTCLinkProgram(std::string name) : RTCProgram(name) { if (amd::Comgr::create_data_set(&link_input_) != AMD_COMGR_STATUS_SUCCESS) { crashWithMessage("Failed to allocate internal hiprtc structure"); } + amd::ScopedLock lock(lock_); + linker_set_.insert(this); +} + +bool RTCLinkProgram::isLinkerValid(RTCLinkProgram* link_program) { + amd::ScopedLock lock(lock_); + if (linker_set_.find(link_program) == linker_set_.end()) { + return false; + } + return true; } bool RTCLinkProgram::AddLinkerOptions(unsigned int num_options, hiprtcJIT_option* options_ptr, diff --git a/hipamd/src/hiprtc/hiprtcInternal.hpp b/hipamd/src/hiprtc/hiprtcInternal.hpp index aae8b64f47..f653e75d6f 100644 --- a/hipamd/src/hiprtc/hiprtcInternal.hpp +++ b/hipamd/src/hiprtc/hiprtcInternal.hpp @@ -271,13 +271,18 @@ class RTCLinkProgram : public RTCProgram { // Private Data Members amd_comgr_data_set_t link_input_; std::vector link_options_; + static std::unordered_set linker_set_; bool AddLinkerDataImpl(std::vector& link_data, hiprtcJITInputType input_type, std::string& link_file_name); public: RTCLinkProgram(std::string name); - ~RTCLinkProgram() { amd::Comgr::destroy_data_set(link_input_); } + ~RTCLinkProgram() { + amd::ScopedLock lock(lock_); + linker_set_.erase(this); + amd::Comgr::destroy_data_set(link_input_); + } // Public Member Functions bool AddLinkerOptions(unsigned int num_options, hiprtcJIT_option* options_ptr, void** options_vals_ptr); @@ -286,6 +291,7 @@ class RTCLinkProgram : public RTCProgram { hiprtcJITInputType input_type); bool LinkComplete(void** bin_out, size_t* size_out); void AppendLinkerOptions() { AppendOptions(HIPRTC_LINK_OPTIONS_APPEND, &link_options_); } + static bool isLinkerValid(RTCLinkProgram* link_program); }; // Thread Local Storage Variables Aggregator Class