diff --git a/projects/clr/hipamd/src/hip_context.cpp b/projects/clr/hipamd/src/hip_context.cpp index d2f64691b7..f22ae22859 100644 --- a/projects/clr/hipamd/src/hip_context.cpp +++ b/projects/clr/hipamd/src/hip_context.cpp @@ -28,10 +28,7 @@ std::vector g_devices; namespace hip { - -thread_local Device* g_device = nullptr; -thread_local std::stack g_ctxtStack; -thread_local hipError_t g_lastError = hipSuccess; +thread_local TlsAggregator tls; Device* host_device = nullptr; //init() is only to be called from the HIP_INIT macro only once @@ -84,13 +81,13 @@ bool init() { } Device* getCurrentDevice() { - return g_device; + return tls.device_; } void setCurrentDevice(unsigned int index) { assert(indexdevices()[0]->getPreferredNumaNode(); + tls.device_ = g_devices[index]; + uint32_t preferredNumaNode = (tls.device_)->devices()[0]->getPreferredNumaNode(); amd::Os::setPreferredNumaNode(preferredNumaNode); } @@ -160,7 +157,7 @@ hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device) // Increment ref count for device primary context g_devices[device]->retain(); - g_ctxtStack.push(g_devices[device]); + tls.ctxt_stack_.push(g_devices[device]); HIP_RETURN(hipSuccess); } @@ -169,15 +166,15 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx) { HIP_INIT_API(hipCtxSetCurrent, ctx); if (ctx == nullptr) { - if(!g_ctxtStack.empty()) { - g_ctxtStack.pop(); + if(!tls.ctxt_stack_.empty()) { + tls.ctxt_stack_.pop(); } } else { - hip::g_device = reinterpret_cast(ctx); - if(!g_ctxtStack.empty()) { - g_ctxtStack.pop(); + hip::tls.device_ = reinterpret_cast(ctx); + if(!tls.ctxt_stack_.empty()) { + tls.ctxt_stack_.pop(); } - g_ctxtStack.push(hip::getCurrentDevice()); + tls.ctxt_stack_.push(hip::getCurrentDevice()); } HIP_RETURN(hipSuccess); @@ -221,8 +218,8 @@ hipError_t hipCtxDestroy(hipCtx_t ctx) { } // Need to remove the ctx of calling thread if its the top one - if (!g_ctxtStack.empty() && g_ctxtStack.top() == dev) { - g_ctxtStack.pop(); + if (!tls.ctxt_stack_.empty() && tls.ctxt_stack_.top() == dev) { + tls.ctxt_stack_.pop(); } // Remove context from global context list @@ -240,11 +237,11 @@ hipError_t hipCtxPopCurrent(hipCtx_t* ctx) { HIP_INIT_API(hipCtxPopCurrent, ctx); hip::Device** dev = reinterpret_cast(ctx); - if (!g_ctxtStack.empty()) { + if (!tls.ctxt_stack_.empty()) { if (dev != nullptr) { - *dev = g_ctxtStack.top(); + *dev = tls.ctxt_stack_.top(); } - g_ctxtStack.pop(); + tls.ctxt_stack_.pop(); } else { DevLogError("Context Stack empty \n"); HIP_RETURN(hipErrorInvalidContext); @@ -261,8 +258,8 @@ hipError_t hipCtxPushCurrent(hipCtx_t ctx) { HIP_RETURN(hipErrorInvalidContext); } - hip::g_device = dev; - g_ctxtStack.push(hip::getCurrentDevice()); + hip::tls.device_ = dev; + tls.ctxt_stack_.push(hip::getCurrentDevice()); HIP_RETURN(hipSuccess); } diff --git a/projects/clr/hipamd/src/hip_error.cpp b/projects/clr/hipamd/src/hip_error.cpp index cf6a2922dd..faa56de5d0 100644 --- a/projects/clr/hipamd/src/hip_error.cpp +++ b/projects/clr/hipamd/src/hip_error.cpp @@ -25,15 +25,15 @@ hipError_t hipGetLastError() { HIP_INIT_API(hipGetLastError); - hipError_t err = hip::g_lastError; - hip::g_lastError = hipSuccess; + hipError_t err = hip::tls.last_error_; + hip::tls.last_error_ = hipSuccess; return err; } hipError_t hipPeekAtLastError() { HIP_INIT_API(hipPeekAtLastError); - hipError_t err = hip::g_lastError; + hipError_t err = hip::tls.last_error_; HIP_RETURN(err); } diff --git a/projects/clr/hipamd/src/hip_graph.cpp b/projects/clr/hipamd/src/hip_graph.cpp index 256df07408..311b0cadb1 100644 --- a/projects/clr/hipamd/src/hip_graph.cpp +++ b/projects/clr/hipamd/src/hip_graph.cpp @@ -27,8 +27,6 @@ std::vector g_captureStreams; amd::Monitor g_captureStreamsLock{"StreamCaptureGlobalList"}; -thread_local std::vector l_captureStreams; -thread_local hipStreamCaptureMode l_streamCaptureMode{hipStreamCaptureModeGlobal}; inline hipError_t ihipGraphAddNode(hipGraphNode_t graphNode, hipGraph_t graph, const hipGraphNode_t* pDependencies, size_t numDependencies) { @@ -835,8 +833,8 @@ hipError_t hipThreadExchangeStreamCaptureMode(hipStreamCaptureMode* mode) { HIP_RETURN(hipErrorInvalidValue); } - auto oldMode = l_streamCaptureMode; - l_streamCaptureMode = *mode; + auto oldMode = hip::tls.stream_capture_mode_; + hip::tls.stream_capture_mode_ = *mode; *mode = oldMode; HIP_RETURN_DURATION(hipSuccess); @@ -864,7 +862,7 @@ hipError_t hipStreamBeginCapture_common(hipStream_t stream, hipStreamCaptureMode s->SetCaptureMode(mode); s->SetOriginStream(); if (mode != hipStreamCaptureModeRelaxed) { - l_captureStreams.push_back(s); + hip::tls.capture_streams_.push_back(s); } if (mode == hipStreamCaptureModeGlobal) { amd::ScopedLock lock(g_captureStreamsLock); @@ -905,12 +903,13 @@ hipError_t hipStreamEndCapture_common(hipStream_t stream, hipGraph_t* pGraph) { } // If mode is not hipStreamCaptureModeRelaxed, hipStreamEndCapture must be called on the stream // from the same thread - const auto& it = std::find(l_captureStreams.begin(), l_captureStreams.end(), s); + const auto& it = std::find(hip::tls.capture_streams_.begin(), + hip::tls.capture_streams_.end(), s); if (s->GetCaptureMode() != hipStreamCaptureModeRelaxed) { - if (it == l_captureStreams.end()) { + if (it == hip::tls.capture_streams_.end()) { return hipErrorStreamCaptureWrongThread; } - l_captureStreams.erase(it); + hip::tls.capture_streams_.erase(it); } if (s->GetCaptureMode() == hipStreamCaptureModeGlobal) { amd::ScopedLock lock(g_captureStreamsLock); diff --git a/projects/clr/hipamd/src/hip_internal.hpp b/projects/clr/hipamd/src/hip_internal.hpp index c1a2671d3f..b7795b9842 100644 --- a/projects/clr/hipamd/src/hip_internal.hpp +++ b/projects/clr/hipamd/src/hip_internal.hpp @@ -82,8 +82,8 @@ static amd::Monitor g_hipInitlock{"hipInit lock"}; HIP_RETURN(hipErrorInvalidDevice); \ } \ } \ - if (hip::g_device == nullptr && g_devices.size() > 0) { \ - hip::g_device = g_devices[0]; \ + if (hip::tls.device_ == nullptr && g_devices.size() > 0) { \ + hip::tls.device_ = g_devices[0]; \ amd::Os::setPreferredNumaNode(g_devices[0]->devices()[0]->getPreferredNumaNode()); \ } \ } @@ -93,8 +93,8 @@ static amd::Monitor g_hipInitlock{"hipInit lock"}; if (!amd::Runtime::initialized()) { \ if (hip::init()) {} \ } \ - if (hip::g_device == nullptr && g_devices.size() > 0) { \ - hip::g_device = g_devices[0]; \ + if (hip::tls.device_ == nullptr && g_devices.size() > 0) { \ + hip::tls.device_ = g_devices[0]; \ amd::Os::setPreferredNumaNode(g_devices[0]->devices()[0]->getPreferredNumaNode()); \ } \ } @@ -130,17 +130,17 @@ static amd::Monitor g_hipInitlock{"hipInit lock"}; HIP_INIT_API_INTERNAL(1, cid, __VA_ARGS__) #define HIP_RETURN_DURATION(ret, ...) \ - hip::g_lastError = ret; \ + hip::tls.last_error_ = ret; \ HIPPrintDuration(amd::LOG_INFO, amd::LOG_API, &startTimeUs, \ "%s: Returned %s : %s", \ - __func__, ihipGetErrorName(hip::g_lastError), \ + __func__, ihipGetErrorName(hip::tls.last_error_), \ ToString( __VA_ARGS__ ).c_str()); \ - return hip::g_lastError; + return hip::tls.last_error_; #define HIP_RETURN(ret, ...) \ - hip::g_lastError = ret; \ - HIP_ERROR_PRINT(hip::g_lastError, __VA_ARGS__) \ - return hip::g_lastError; + hip::tls.last_error_ = ret; \ + HIP_ERROR_PRINT(hip::tls.last_error_, __VA_ARGS__) \ + return hip::tls.last_error_; #define HIP_RETURN_ONFAIL(func) \ do { \ @@ -161,12 +161,12 @@ static amd::Monitor g_hipInitlock{"hipInit lock"}; } while (0); #define CHECK_STREAM_CAPTURE_SUPPORTED() \ - if (l_streamCaptureMode == hipStreamCaptureModeThreadLocal) { \ - if (l_captureStreams.size() != 0) { \ + if (hip::tls.stream_capture_mode_ == hipStreamCaptureModeThreadLocal) { \ + if (hip::tls.capture_streams_.size() != 0) { \ HIP_RETURN(hipErrorStreamCaptureUnsupported); \ } \ - } else if (l_streamCaptureMode == hipStreamCaptureModeGlobal) { \ - if (l_captureStreams.size() != 0) { \ + } else if (hip::tls.stream_capture_mode_ == hipStreamCaptureModeGlobal) { \ + if (hip::tls.capture_streams_.size() != 0) { \ HIP_RETURN(hipErrorStreamCaptureUnsupported); \ } \ amd::ScopedLock lock(g_captureStreamsLock); \ @@ -205,6 +205,26 @@ namespace hc { class accelerator; class accelerator_view; }; + +struct ihipExec_t { + dim3 gridDim_; + dim3 blockDim_; + size_t sharedMem_; + hipStream_t hStream_; + std::vector arguments_; +}; + +class stream_per_thread { +private: + std::vector m_streams; +public: + stream_per_thread(); + stream_per_thread(const stream_per_thread& ) = delete; + void operator=(const stream_per_thread& ) = delete; + ~stream_per_thread(); + hipStream_t get(); +}; + namespace hip { class Device; class MemoryPool; @@ -449,9 +469,26 @@ namespace hip { void RemoveStreamFromPools(Stream* stream); }; - /// Current thread's device - extern thread_local Device* g_device; - extern thread_local hipError_t g_lastError; + /// Thread Local Storage Variables Aggregator Class + class TlsAggregator { + public: + Device* device_; + std::stack ctxt_stack_; + hipError_t last_error_; + std::vector capture_streams_; + hipStreamCaptureMode stream_capture_mode_; + std::stack exec_stack_; + stream_per_thread stream_per_thread_obj_; + + TlsAggregator(): device_(nullptr), + last_error_(hipSuccess), + stream_capture_mode_(hipStreamCaptureModeGlobal) { + } + ~TlsAggregator() { + } + }; + extern thread_local TlsAggregator tls; + /// Device representing the host - for pinned memory extern Device* host_device; @@ -475,16 +512,9 @@ namespace hip { extern bool isValid(hipStream_t& stream); extern amd::Monitor hipArraySetLock; extern std::unordered_set hipArraySet; -}; +}; // namespace hip extern void WaitThenDecrementSignal(hipStream_t stream, hipError_t status, void* user_data); -struct ihipExec_t { - dim3 gridDim_; - dim3 blockDim_; - size_t sharedMem_; - hipStream_t hStream_; - std::vector arguments_; -}; /// Wait all active streams on the blocking queue. The method enqueues a wait command and /// doesn't stall the current thread @@ -515,6 +545,4 @@ constexpr bool kMarkerDisableFlush = true; //!< Avoids command batch flush in extern std::vector g_captureStreams; extern amd::Monitor g_captureStreamsLock; -extern thread_local std::vector l_captureStreams; -extern thread_local hipStreamCaptureMode l_streamCaptureMode; #endif // HIP_SRC_HIP_INTERNAL_H diff --git a/projects/clr/hipamd/src/hip_platform.cpp b/projects/clr/hipamd/src/hip_platform.cpp index d92737e9b4..6a16d45ed2 100644 --- a/projects/clr/hipamd/src/hip_platform.cpp +++ b/projects/clr/hipamd/src/hip_platform.cpp @@ -29,7 +29,6 @@ constexpr unsigned __hipFatMAGIC2 = 0x48495046; // "HIPF" -thread_local std::stack execStack_; PlatformState* PlatformState::platform_; // Initiaized as nullptr by default // forward declaration of methods required for __hipRegisrterManagedVar @@ -920,7 +919,7 @@ hipError_t PlatformState::initStatManagedVarDevicePtr(int deviceId) { } void PlatformState::setupArgument(const void* arg, size_t size, size_t offset) { - auto& arguments = execStack_.top().arguments_; + auto& arguments = hip::tls.exec_stack_.top().arguments_; if (arguments.size() < offset + size) { arguments.resize(offset + size); @@ -931,10 +930,10 @@ void PlatformState::setupArgument(const void* arg, size_t size, size_t offset) { void PlatformState::configureCall(dim3 gridDim, dim3 blockDim, size_t sharedMem, hipStream_t stream) { - execStack_.push(ihipExec_t{gridDim, blockDim, sharedMem, stream}); + hip::tls.exec_stack_.push(ihipExec_t{gridDim, blockDim, sharedMem, stream}); } void PlatformState::popExec(ihipExec_t& exec) { - exec = std::move(execStack_.top()); - execStack_.pop(); + exec = std::move(hip::tls.exec_stack_.top()); + hip::tls.exec_stack_.pop(); } diff --git a/projects/clr/hipamd/src/hip_stream.cpp b/projects/clr/hipamd/src/hip_stream.cpp index 72708c93b6..e67e831dad 100644 --- a/projects/clr/hipamd/src/hip_stream.cpp +++ b/projects/clr/hipamd/src/hip_stream.cpp @@ -281,55 +281,50 @@ static hipError_t ihipStreamCreate(hipStream_t* stream, } // ================================================================================================ -class stream_per_thread { -private: - std::vector m_streams; -public: - stream_per_thread() { + +stream_per_thread::stream_per_thread() { + m_streams.resize(g_devices.size()); + for (auto &stream : m_streams) { + stream = nullptr; + } +} + +stream_per_thread::~stream_per_thread() { + for (auto &stream:m_streams) { + if (stream != nullptr && hip::isValid(stream)) { + delete reinterpret_cast(stream); + stream = nullptr; + } + } +} + +hipStream_t stream_per_thread::get() { + hip::Device* device = hip::getCurrentDevice(); + int currDev = device->deviceId(); + // This is to make sure m_streams is not empty + if (m_streams.empty()) { m_streams.resize(g_devices.size()); for (auto &stream : m_streams) { stream = nullptr; } } - stream_per_thread(const stream_per_thread& ) = delete; - void operator=(const stream_per_thread& ) = delete; - ~stream_per_thread() { - for (auto &stream:m_streams) { - if (stream != nullptr && hip::isValid(stream)) { - delete reinterpret_cast(stream); - stream = nullptr; - } + // There is a scenario where hipResetDevice destroys stream per thread + // hence isValid check is required to make sure only valid stream is used + if (m_streams[currDev] == nullptr || !hip::isValid(m_streams[currDev])) { + hipError_t status = ihipStreamCreate(&m_streams[currDev], hipStreamDefault, + hip::Stream::Priority::Normal); + if (status != hipSuccess) { + DevLogError("Stream creation failed\n"); } } + return m_streams[currDev]; +} - hipStream_t get() { - hip::Device* device = hip::getCurrentDevice(); - int currDev = device->deviceId(); - // This is to make sure m_streams is not empty - if (m_streams.empty()) { - m_streams.resize(g_devices.size()); - for (auto &stream : m_streams) { - stream = nullptr; - } - } - // There is a scenario where hipResetDevice destroys stream per thread - // hence isValid check is required to make sure only valid stream is used - if (m_streams[currDev] == nullptr || !hip::isValid(m_streams[currDev])) { - hipError_t status = ihipStreamCreate(&m_streams[currDev], hipStreamDefault, - hip::Stream::Priority::Normal); - if (status != hipSuccess) { - DevLogError("Stream creation failed\n"); - } - } - return m_streams[currDev]; - } -}; -thread_local stream_per_thread streamPerThreadObj; // ================================================================================================ void getStreamPerThread(hipStream_t& stream) { if (stream == hipStreamPerThread) { - stream = streamPerThreadObj.get(); + stream = hip::tls.stream_per_thread_obj_.get(); } } @@ -469,9 +464,10 @@ hipError_t hipStreamDestroy(hipStream_t stream) { if (g_it != g_captureStreams.end()) { g_captureStreams.erase(g_it); } - const auto& l_it = std::find(l_captureStreams.begin(), l_captureStreams.end(), s); - if (l_it != l_captureStreams.end()) { - l_captureStreams.erase(l_it); + const auto& l_it = std::find(hip::tls.capture_streams_.begin(), + hip::tls.capture_streams_.end(), s); + if (l_it != hip::tls.capture_streams_.end()) { + hip::tls.capture_streams_.erase(l_it); } delete s; diff --git a/projects/clr/hipamd/src/hiprtc/hiprtc.cpp b/projects/clr/hipamd/src/hiprtc/hiprtc.cpp index fb09d88434..7709cc86b1 100644 --- a/projects/clr/hipamd/src/hiprtc/hiprtc.cpp +++ b/projects/clr/hipamd/src/hiprtc/hiprtc.cpp @@ -24,7 +24,7 @@ THE SOFTWARE. #include "hiprtcInternal.hpp" namespace hiprtc { -thread_local hiprtcResult g_lastRtcError = HIPRTC_SUCCESS; +thread_local TlsAggregator tls; } const char* hiprtcGetErrorString(hiprtcResult x) { diff --git a/projects/clr/hipamd/src/hiprtc/hiprtcInternal.hpp b/projects/clr/hipamd/src/hiprtc/hiprtcInternal.hpp index 3a6d2f867a..9e0d831b3f 100644 --- a/projects/clr/hipamd/src/hiprtc/hiprtcInternal.hpp +++ b/projects/clr/hipamd/src/hiprtc/hiprtcInternal.hpp @@ -72,10 +72,10 @@ static amd::Monitor g_hiprtcInitlock {"hiprtcInit lock"}; hiprtc::internal::ToString(__VA_ARGS__).c_str()); #define HIPRTC_RETURN(ret) \ - hiprtc::g_lastRtcError = (ret); \ + hiprtc::tls.last_rtc_error_ = (ret); \ ClPrint(amd::LOG_INFO, amd::LOG_API, "%s: Returned %s", __func__, \ - hiprtcGetErrorString(hiprtc::g_lastRtcError)); \ - return hiprtc::g_lastRtcError; + hiprtcGetErrorString(hiprtc::tls.last_rtc_error_)); \ + return hiprtc::tls.last_rtc_error_; namespace hiprtc { @@ -106,9 +106,9 @@ protected: // Member Functions bool findIsa(); - - // Data Members - std::string name_; + + // Data Members + std::string name_; std::string isa_; std::string build_log_; std::vector executable_; @@ -126,7 +126,7 @@ class RTCCompileProgram : public RTCProgram { std::string source_name_; std::map stripped_names_; std::map demangled_names_; - + std::vector compile_options_; std::vector link_options_; @@ -235,4 +235,15 @@ public: bool LinkComplete(void** bin_out, size_t* size_out); }; +// Thread Local Storage Variables Aggregator Class +class TlsAggregator { +public: + hiprtcResult last_rtc_error_; + + TlsAggregator(): last_rtc_error_(HIPRTC_SUCCESS) { + } + ~TlsAggregator() { + } +}; +extern thread_local TlsAggregator tls; } // namespace hiprtc diff --git a/projects/clr/hipamd/src/hiprtc_internal.hpp b/projects/clr/hipamd/src/hiprtc_internal.hpp index 55fbff8200..0d50300459 100644 --- a/projects/clr/hipamd/src/hiprtc_internal.hpp +++ b/projects/clr/hipamd/src/hiprtc_internal.hpp @@ -56,10 +56,10 @@ extern "C" char * __cxa_demangle(const char *mangled_name, char *output_buffer, HIP_INIT_VOID(); #define HIPRTC_RETURN(ret) \ - hiprtc::g_lastRtcError = ret; \ + hiprtc::tls.last_rtc_error_ = ret; \ ClPrint(amd::LOG_INFO, amd::LOG_API, "%s: Returned %s", __func__, \ - hiprtcGetErrorString(hiprtc::g_lastRtcError)); \ - return hiprtc::g_lastRtcError; + hiprtcGetErrorString(hiprtc::tls.last_rtc_error_)); \ + return hiprtc::tls.last_rtc_error_; #endif // HIPRTC_SRC_HIP_INTERNAL_H