SWDEV-351969 - TLS Optimization

- Aggregate all TLS(Thread Local Storage) variables into a single class
- This is to improve cache accesses per thread

Change-Id: Ic8361eaeae290fff00254684e309471958365eb9


[ROCm/clr commit: 8b391ef18c]
Этот коммит содержится в:
Rakesh Roy
2022-09-26 15:59:27 +05:30
коммит произвёл Rakesh Roy
родитель 2be45a82ec
Коммит f149b21399
9 изменённых файлов: 145 добавлений и 115 удалений
+18 -21
Просмотреть файл
@@ -28,10 +28,7 @@
std::vector<hip::Device*> g_devices;
namespace hip {
thread_local Device* g_device = nullptr;
thread_local std::stack<Device*> g_ctxtStack;
thread_local hipError_t g_lastError = hipSuccess;
thread_local TlsAggregator tls;
Device* host_device = nullptr;
// init() must be called exactly once, and only from the HIP_INIT macro
@@ -84,13 +81,13 @@ bool init() {
}
/// Returns the device currently selected for the calling thread.
/// The value lives in thread-local storage, so every thread sees its own selection.
/// The stored pointer starts out as nullptr until something assigns a device on this thread.
Device* getCurrentDevice() {
return g_device;
return tls.device_;
}
/// Makes g_devices[index] the calling thread's current device and forwards that
/// device's preferred NUMA node to amd::Os::setPreferredNumaNode().
/// @param index Position in g_devices; must be smaller than g_devices.size().
///              The bound is enforced only through assert().
void setCurrentDevice(unsigned int index) {
assert(index<g_devices.size());
g_device = g_devices[index];
uint32_t preferredNumaNode = g_device->devices()[0]->getPreferredNumaNode();
tls.device_ = g_devices[index];
// The NUMA preference is taken from the first entry of the selected device's devices() list.
uint32_t preferredNumaNode = (tls.device_)->devices()[0]->getPreferredNumaNode();
amd::Os::setPreferredNumaNode(preferredNumaNode);
}
@@ -160,7 +157,7 @@ hipError_t hipCtxCreate(hipCtx_t *ctx, unsigned int flags, hipDevice_t device)
// Increment ref count for device primary context
g_devices[device]->retain();
g_ctxtStack.push(g_devices[device]);
tls.ctxt_stack_.push(g_devices[device]);
HIP_RETURN(hipSuccess);
}
@@ -169,15 +166,15 @@ hipError_t hipCtxSetCurrent(hipCtx_t ctx) {
HIP_INIT_API(hipCtxSetCurrent, ctx);
if (ctx == nullptr) {
if(!g_ctxtStack.empty()) {
g_ctxtStack.pop();
if(!tls.ctxt_stack_.empty()) {
tls.ctxt_stack_.pop();
}
} else {
hip::g_device = reinterpret_cast<hip::Device*>(ctx);
if(!g_ctxtStack.empty()) {
g_ctxtStack.pop();
hip::tls.device_ = reinterpret_cast<hip::Device*>(ctx);
if(!tls.ctxt_stack_.empty()) {
tls.ctxt_stack_.pop();
}
g_ctxtStack.push(hip::getCurrentDevice());
tls.ctxt_stack_.push(hip::getCurrentDevice());
}
HIP_RETURN(hipSuccess);
@@ -221,8 +218,8 @@ hipError_t hipCtxDestroy(hipCtx_t ctx) {
}
// Remove the ctx from the calling thread's stack if it is the top entry
if (!g_ctxtStack.empty() && g_ctxtStack.top() == dev) {
g_ctxtStack.pop();
if (!tls.ctxt_stack_.empty() && tls.ctxt_stack_.top() == dev) {
tls.ctxt_stack_.pop();
}
// Remove context from global context list
@@ -240,11 +237,11 @@ hipError_t hipCtxPopCurrent(hipCtx_t* ctx) {
HIP_INIT_API(hipCtxPopCurrent, ctx);
hip::Device** dev = reinterpret_cast<hip::Device**>(ctx);
if (!g_ctxtStack.empty()) {
if (!tls.ctxt_stack_.empty()) {
if (dev != nullptr) {
*dev = g_ctxtStack.top();
*dev = tls.ctxt_stack_.top();
}
g_ctxtStack.pop();
tls.ctxt_stack_.pop();
} else {
DevLogError("Context Stack empty \n");
HIP_RETURN(hipErrorInvalidContext);
@@ -261,8 +258,8 @@ hipError_t hipCtxPushCurrent(hipCtx_t ctx) {
HIP_RETURN(hipErrorInvalidContext);
}
hip::g_device = dev;
g_ctxtStack.push(hip::getCurrentDevice());
hip::tls.device_ = dev;
tls.ctxt_stack_.push(hip::getCurrentDevice());
HIP_RETURN(hipSuccess);
}
+3 -3
Просмотреть файл
@@ -25,15 +25,15 @@
/// Returns the last error recorded for the calling thread and resets the stored
/// value to hipSuccess, so a second call without an intervening failure yields hipSuccess.
hipError_t hipGetLastError()
{
HIP_INIT_API(hipGetLastError);
hipError_t err = hip::g_lastError;
hip::g_lastError = hipSuccess;
hipError_t err = hip::tls.last_error_;
hip::tls.last_error_ = hipSuccess;
// Plain return instead of HIP_RETURN: HIP_RETURN would write err back into the
// per-thread last-error slot and undo the reset above.
return err;
}
/// Returns the last error recorded for the calling thread without clearing it.
hipError_t hipPeekAtLastError()
{
HIP_INIT_API(hipPeekAtLastError);
hipError_t err = hip::g_lastError;
hipError_t err = hip::tls.last_error_;
// HIP_RETURN stores err back as the last error, so the recorded value is left unchanged.
HIP_RETURN(err);
}
+7 -8
Просмотреть файл
@@ -27,8 +27,6 @@
std::vector<hip::Stream*> g_captureStreams;
amd::Monitor g_captureStreamsLock{"StreamCaptureGlobalList"};
thread_local std::vector<hip::Stream*> l_captureStreams;
thread_local hipStreamCaptureMode l_streamCaptureMode{hipStreamCaptureModeGlobal};
inline hipError_t ihipGraphAddNode(hipGraphNode_t graphNode, hipGraph_t graph,
const hipGraphNode_t* pDependencies, size_t numDependencies) {
@@ -835,8 +833,8 @@ hipError_t hipThreadExchangeStreamCaptureMode(hipStreamCaptureMode* mode) {
HIP_RETURN(hipErrorInvalidValue);
}
auto oldMode = l_streamCaptureMode;
l_streamCaptureMode = *mode;
auto oldMode = hip::tls.stream_capture_mode_;
hip::tls.stream_capture_mode_ = *mode;
*mode = oldMode;
HIP_RETURN_DURATION(hipSuccess);
@@ -864,7 +862,7 @@ hipError_t hipStreamBeginCapture_common(hipStream_t stream, hipStreamCaptureMode
s->SetCaptureMode(mode);
s->SetOriginStream();
if (mode != hipStreamCaptureModeRelaxed) {
l_captureStreams.push_back(s);
hip::tls.capture_streams_.push_back(s);
}
if (mode == hipStreamCaptureModeGlobal) {
amd::ScopedLock lock(g_captureStreamsLock);
@@ -905,12 +903,13 @@ hipError_t hipStreamEndCapture_common(hipStream_t stream, hipGraph_t* pGraph) {
}
// If mode is not hipStreamCaptureModeRelaxed, hipStreamEndCapture must be called on the stream
// from the same thread
const auto& it = std::find(l_captureStreams.begin(), l_captureStreams.end(), s);
const auto& it = std::find(hip::tls.capture_streams_.begin(),
hip::tls.capture_streams_.end(), s);
if (s->GetCaptureMode() != hipStreamCaptureModeRelaxed) {
if (it == l_captureStreams.end()) {
if (it == hip::tls.capture_streams_.end()) {
return hipErrorStreamCaptureWrongThread;
}
l_captureStreams.erase(it);
hip::tls.capture_streams_.erase(it);
}
if (s->GetCaptureMode() == hipStreamCaptureModeGlobal) {
amd::ScopedLock lock(g_captureStreamsLock);
+55 -27
Просмотреть файл
@@ -82,8 +82,8 @@ static amd::Monitor g_hipInitlock{"hipInit lock"};
HIP_RETURN(hipErrorInvalidDevice); \
} \
} \
if (hip::g_device == nullptr && g_devices.size() > 0) { \
hip::g_device = g_devices[0]; \
if (hip::tls.device_ == nullptr && g_devices.size() > 0) { \
hip::tls.device_ = g_devices[0]; \
amd::Os::setPreferredNumaNode(g_devices[0]->devices()[0]->getPreferredNumaNode()); \
} \
}
@@ -93,8 +93,8 @@ static amd::Monitor g_hipInitlock{"hipInit lock"};
if (!amd::Runtime::initialized()) { \
if (hip::init()) {} \
} \
if (hip::g_device == nullptr && g_devices.size() > 0) { \
hip::g_device = g_devices[0]; \
if (hip::tls.device_ == nullptr && g_devices.size() > 0) { \
hip::tls.device_ = g_devices[0]; \
amd::Os::setPreferredNumaNode(g_devices[0]->devices()[0]->getPreferredNumaNode()); \
} \
}
@@ -130,17 +130,17 @@ static amd::Monitor g_hipInitlock{"hipInit lock"};
HIP_INIT_API_INTERNAL(1, cid, __VA_ARGS__)
/// Records `ret` as the calling thread's last error, logs the function name, the error
/// name and the stringified extra arguments through HIPPrintDuration, then returns the
/// recorded error from the enclosing function.
/// Requires a variable named `startTimeUs` to be in scope at the expansion site.
/// NOTE(review): expands to several statements without a do { } while (0) wrapper, so it
/// must not be used as the sole body of an unbraced if/else.
#define HIP_RETURN_DURATION(ret, ...) \
hip::g_lastError = ret; \
hip::tls.last_error_ = ret; \
HIPPrintDuration(amd::LOG_INFO, amd::LOG_API, &startTimeUs, \
"%s: Returned %s : %s", \
__func__, ihipGetErrorName(hip::g_lastError), \
__func__, ihipGetErrorName(hip::tls.last_error_), \
ToString( __VA_ARGS__ ).c_str()); \
return hip::g_lastError;
return hip::tls.last_error_;
/// Records `ret` as the calling thread's last error, hands it (plus any extra arguments)
/// to HIP_ERROR_PRINT, then returns the recorded error from the enclosing function.
/// NOTE(review): expands to several statements without a do { } while (0) wrapper, so it
/// must not be used as the sole body of an unbraced if/else.
#define HIP_RETURN(ret, ...) \
hip::g_lastError = ret; \
HIP_ERROR_PRINT(hip::g_lastError, __VA_ARGS__) \
return hip::g_lastError;
hip::tls.last_error_ = ret; \
HIP_ERROR_PRINT(hip::tls.last_error_, __VA_ARGS__) \
return hip::tls.last_error_;
#define HIP_RETURN_ONFAIL(func) \
do { \
@@ -161,12 +161,12 @@ static amd::Monitor g_hipInitlock{"hipInit lock"};
} while (0);
#define CHECK_STREAM_CAPTURE_SUPPORTED() \
if (l_streamCaptureMode == hipStreamCaptureModeThreadLocal) { \
if (l_captureStreams.size() != 0) { \
if (hip::tls.stream_capture_mode_ == hipStreamCaptureModeThreadLocal) { \
if (hip::tls.capture_streams_.size() != 0) { \
HIP_RETURN(hipErrorStreamCaptureUnsupported); \
} \
} else if (l_streamCaptureMode == hipStreamCaptureModeGlobal) { \
if (l_captureStreams.size() != 0) { \
} else if (hip::tls.stream_capture_mode_ == hipStreamCaptureModeGlobal) { \
if (hip::tls.capture_streams_.size() != 0) { \
HIP_RETURN(hipErrorStreamCaptureUnsupported); \
} \
amd::ScopedLock lock(g_captureStreamsLock); \
@@ -205,6 +205,26 @@ namespace hc {
class accelerator;
class accelerator_view;
};
/// One pending launch configuration. PlatformState::configureCall() builds an entry from
/// its four arguments and pushes it on the calling thread's exec stack;
/// PlatformState::setupArgument() grows arguments_ of the top entry; PlatformState::popExec()
/// moves the top entry out again.
struct ihipExec_t {
/// Grid dimensions handed to configureCall().
dim3 gridDim_;
/// Block dimensions handed to configureCall().
dim3 blockDim_;
/// Shared-memory size handed to configureCall() (presumably in bytes -- TODO confirm).
size_t sharedMem_;
/// Stream handed to configureCall().
hipStream_t hStream_;
/// Raw argument buffer; setupArgument() resizes it to at least offset + size.
std::vector<char> arguments_;
};
/// Owner of the per-thread streams: one hipStream_t slot per entry of g_devices.
/// Declared here so it can be embedded in hip::TlsAggregator; the member functions
/// are defined out of line.
class stream_per_thread {
private:
// One slot per device, indexed by device id; nullptr until a stream is created for that device.
std::vector<hipStream_t> m_streams;
public:
// Sizes m_streams to g_devices.size() with every slot set to nullptr.
stream_per_thread();
// Non-copyable: the destructor deletes the stored streams, so a copy would double-delete.
stream_per_thread(const stream_per_thread& ) = delete;
void operator=(const stream_per_thread& ) = delete;
// Deletes every stored stream that is non-null and still reported valid by hip::isValid().
~stream_per_thread();
// Returns the stream slot for the current device, creating the stream first when the
// slot is empty or no longer valid.
hipStream_t get();
};
namespace hip {
class Device;
class MemoryPool;
@@ -449,9 +469,26 @@ namespace hip {
void RemoveStreamFromPools(Stream* stream);
};
/// Current thread's device
extern thread_local Device* g_device;
extern thread_local hipError_t g_lastError;
/// Aggregates every per-thread variable of the HIP runtime into a single object, so one
/// thread_local instance (hip::tls) carries all of them.
/// Member order is kept as declared; it fixes both the object layout and the order in
/// which the members are constructed.
class TlsAggregator {
 public:
  TlsAggregator() = default;
  ~TlsAggregator() = default;

  /// Device currently selected for this thread; nullptr until one is assigned.
  Device* device_ = nullptr;
  /// Context stack manipulated by the hipCtx* entry points.
  std::stack<Device*> ctxt_stack_;
  /// Last error recorded by HIP_RETURN / HIP_RETURN_DURATION on this thread.
  hipError_t last_error_ = hipSuccess;
  /// Streams for which this thread began a capture in a mode other than relaxed.
  std::vector<hip::Stream*> capture_streams_;
  /// Capture mode swapped by hipThreadExchangeStreamCaptureMode().
  hipStreamCaptureMode stream_capture_mode_ = hipStreamCaptureModeGlobal;
  /// Launch configurations pushed by PlatformState::configureCall().
  std::stack<ihipExec_t> exec_stack_;
  /// Per-device streams backing hipStreamPerThread for this thread.
  stream_per_thread stream_per_thread_obj_;
};
extern thread_local TlsAggregator tls;
/// Device representing the host - for pinned memory
extern Device* host_device;
@@ -475,16 +512,9 @@ namespace hip {
extern bool isValid(hipStream_t& stream);
extern amd::Monitor hipArraySetLock;
extern std::unordered_set<hipArray*> hipArraySet;
};
}; // namespace hip
extern void WaitThenDecrementSignal(hipStream_t stream, hipError_t status, void* user_data);
struct ihipExec_t {
dim3 gridDim_;
dim3 blockDim_;
size_t sharedMem_;
hipStream_t hStream_;
std::vector<char> arguments_;
};
/// Wait all active streams on the blocking queue. The method enqueues a wait command and
/// doesn't stall the current thread
@@ -515,6 +545,4 @@ constexpr bool kMarkerDisableFlush = true; //!< Avoids command batch flush in
extern std::vector<hip::Stream*> g_captureStreams;
extern amd::Monitor g_captureStreamsLock;
extern thread_local std::vector<hip::Stream*> l_captureStreams;
extern thread_local hipStreamCaptureMode l_streamCaptureMode;
#endif // HIP_SRC_HIP_INTERNAL_H
+4 -5
Просмотреть файл
@@ -29,7 +29,6 @@
constexpr unsigned __hipFatMAGIC2 = 0x48495046; // "HIPF"
thread_local std::stack<ihipExec_t> execStack_;
PlatformState* PlatformState::platform_; // Initialized as nullptr by default
// forward declaration of methods required for __hipRegisterManagedVar
@@ -920,7 +919,7 @@ hipError_t PlatformState::initStatManagedVarDevicePtr(int deviceId) {
}
void PlatformState::setupArgument(const void* arg, size_t size, size_t offset) {
auto& arguments = execStack_.top().arguments_;
auto& arguments = hip::tls.exec_stack_.top().arguments_;
if (arguments.size() < offset + size) {
arguments.resize(offset + size);
@@ -931,10 +930,10 @@ void PlatformState::setupArgument(const void* arg, size_t size, size_t offset) {
/// Pushes a new launch configuration on the calling thread's exec stack.
/// The entry is built from the four arguments; its arguments_ buffer starts empty.
void PlatformState::configureCall(dim3 gridDim, dim3 blockDim, size_t sharedMem,
hipStream_t stream) {
execStack_.push(ihipExec_t{gridDim, blockDim, sharedMem, stream});
hip::tls.exec_stack_.push(ihipExec_t{gridDim, blockDim, sharedMem, stream});
}
/// Moves the top launch configuration of the calling thread's exec stack into `exec`
/// and removes it from the stack.
/// NOTE(review): there is no emptiness check; calling top() on an empty std::stack is
/// undefined behavior, so a matching configureCall() on the same thread is required.
void PlatformState::popExec(ihipExec_t& exec) {
exec = std::move(execStack_.top());
execStack_.pop();
exec = std::move(hip::tls.exec_stack_.top());
hip::tls.exec_stack_.pop();
}
+36 -40
Просмотреть файл
@@ -281,55 +281,50 @@ static hipError_t ihipStreamCreate(hipStream_t* stream,
}
// ================================================================================================
class stream_per_thread {
private:
std::vector<hipStream_t> m_streams;
public:
stream_per_thread() {
stream_per_thread::stream_per_thread() {
m_streams.resize(g_devices.size());
for (auto &stream : m_streams) {
stream = nullptr;
}
}
// Releases every per-device stream this object still holds.
stream_per_thread::~stream_per_thread() {
  for (hipStream_t& slot : m_streams) {
    // Skip empty slots and streams that hip::isValid() no longer recognises;
    // those must not be deleted here.
    if (slot == nullptr || !hip::isValid(slot)) {
      continue;
    }
    delete reinterpret_cast<hip::Stream*>(slot);
    slot = nullptr;
  }
}
hipStream_t stream_per_thread::get() {
hip::Device* device = hip::getCurrentDevice();
int currDev = device->deviceId();
// This is to make sure m_streams is not empty
if (m_streams.empty()) {
m_streams.resize(g_devices.size());
for (auto &stream : m_streams) {
stream = nullptr;
}
}
stream_per_thread(const stream_per_thread& ) = delete;
void operator=(const stream_per_thread& ) = delete;
~stream_per_thread() {
for (auto &stream:m_streams) {
if (stream != nullptr && hip::isValid(stream)) {
delete reinterpret_cast<hip::Stream*>(stream);
stream = nullptr;
}
// There is a scenario where hipResetDevice destroys stream per thread
// hence isValid check is required to make sure only valid stream is used
if (m_streams[currDev] == nullptr || !hip::isValid(m_streams[currDev])) {
hipError_t status = ihipStreamCreate(&m_streams[currDev], hipStreamDefault,
hip::Stream::Priority::Normal);
if (status != hipSuccess) {
DevLogError("Stream creation failed\n");
}
}
return m_streams[currDev];
}
hipStream_t get() {
hip::Device* device = hip::getCurrentDevice();
int currDev = device->deviceId();
// This is to make sure m_streams is not empty
if (m_streams.empty()) {
m_streams.resize(g_devices.size());
for (auto &stream : m_streams) {
stream = nullptr;
}
}
// There is a scenario where hipResetDevice destroys stream per thread
// hence isValid check is required to make sure only valid stream is used
if (m_streams[currDev] == nullptr || !hip::isValid(m_streams[currDev])) {
hipError_t status = ihipStreamCreate(&m_streams[currDev], hipStreamDefault,
hip::Stream::Priority::Normal);
if (status != hipSuccess) {
DevLogError("Stream creation failed\n");
}
}
return m_streams[currDev];
}
};
thread_local stream_per_thread streamPerThreadObj;
// ================================================================================================
// Resolves the hipStreamPerThread placeholder: when `stream` equals hipStreamPerThread it
// is overwritten in place with the calling thread's per-thread stream; any other value is
// left untouched.
void getStreamPerThread(hipStream_t& stream) {
if (stream == hipStreamPerThread) {
stream = streamPerThreadObj.get();
stream = hip::tls.stream_per_thread_obj_.get();
}
}
@@ -469,9 +464,10 @@ hipError_t hipStreamDestroy(hipStream_t stream) {
if (g_it != g_captureStreams.end()) {
g_captureStreams.erase(g_it);
}
const auto& l_it = std::find(l_captureStreams.begin(), l_captureStreams.end(), s);
if (l_it != l_captureStreams.end()) {
l_captureStreams.erase(l_it);
const auto& l_it = std::find(hip::tls.capture_streams_.begin(),
hip::tls.capture_streams_.end(), s);
if (l_it != hip::tls.capture_streams_.end()) {
hip::tls.capture_streams_.erase(l_it);
}
delete s;
+1 -1
Просмотреть файл
@@ -24,7 +24,7 @@ THE SOFTWARE.
#include "hiprtcInternal.hpp"
namespace hiprtc {
thread_local hiprtcResult g_lastRtcError = HIPRTC_SUCCESS;
thread_local TlsAggregator tls;
}
const char* hiprtcGetErrorString(hiprtcResult x) {
+18 -7
Просмотреть файл
@@ -72,10 +72,10 @@ static amd::Monitor g_hiprtcInitlock {"hiprtcInit lock"};
hiprtc::internal::ToString(__VA_ARGS__).c_str());
// Records (ret) as the calling thread's last hiprtc error, logs the function name and the
// error string through ClPrint, then returns the recorded error from the enclosing function.
// NOTE(review): expands to several statements without a do { } while (0) wrapper, so it
// must not be used as the sole body of an unbraced if/else.
#define HIPRTC_RETURN(ret) \
hiprtc::g_lastRtcError = (ret); \
hiprtc::tls.last_rtc_error_ = (ret); \
ClPrint(amd::LOG_INFO, amd::LOG_API, "%s: Returned %s", __func__, \
hiprtcGetErrorString(hiprtc::g_lastRtcError)); \
return hiprtc::g_lastRtcError;
hiprtcGetErrorString(hiprtc::tls.last_rtc_error_)); \
return hiprtc::tls.last_rtc_error_;
namespace hiprtc {
@@ -106,9 +106,9 @@ protected:
// Member Functions
bool findIsa();
// Data Members
std::string name_;
// Data Members
std::string name_;
std::string isa_;
std::string build_log_;
std::vector<char> executable_;
@@ -126,7 +126,7 @@ class RTCCompileProgram : public RTCProgram {
std::string source_name_;
std::map<std::string, std::string> stripped_names_;
std::map<std::string, std::string> demangled_names_;
std::vector<std::string> compile_options_;
std::vector<std::string> link_options_;
@@ -235,4 +235,15 @@ public:
bool LinkComplete(void** bin_out, size_t* size_out);
};
// Aggregates the per-thread variables of hiprtc into a single object, so one
// thread_local instance (hiprtc::tls) carries all of them.
class TlsAggregator {
 public:
  TlsAggregator() = default;
  ~TlsAggregator() = default;

  // Last result recorded by HIPRTC_RETURN on this thread.
  hiprtcResult last_rtc_error_ = HIPRTC_SUCCESS;
};
extern thread_local TlsAggregator tls;
} // namespace hiprtc
+3 -3
Просмотреть файл
@@ -56,10 +56,10 @@ extern "C" char * __cxa_demangle(const char *mangled_name, char *output_buffer,
HIP_INIT_VOID();
// Records ret as the calling thread's last hiprtc error, logs the function name and the
// error string through ClPrint, then returns the recorded error from the enclosing function.
// NOTE(review): `ret` is expanded without parentheses here, and the macro expands to
// several statements without a do { } while (0) wrapper -- pass a simple expression and
// do not use it as the sole body of an unbraced if/else.
#define HIPRTC_RETURN(ret) \
hiprtc::g_lastRtcError = ret; \
hiprtc::tls.last_rtc_error_ = ret; \
ClPrint(amd::LOG_INFO, amd::LOG_API, "%s: Returned %s", __func__, \
hiprtcGetErrorString(hiprtc::g_lastRtcError)); \
return hiprtc::g_lastRtcError;
hiprtcGetErrorString(hiprtc::tls.last_rtc_error_)); \
return hiprtc::tls.last_rtc_error_;
#endif // HIPRTC_SRC_HIP_INTERNAL_H