SWDEV-550521 - Add the JIT options for HIPRTC linker APIs (#762)

* SWDEV-550521 - Add the JIT options for HIPRTC linker APIs

* Address review comments about using C++ datatypes
This commit is contained in:
Satyanvesh Dittakavi
2025-09-09 12:24:08 +05:30
committato da GitHub
parent 662ee1c7e1
commit 85065dab32
2 ha cambiato i file con 146 aggiunte e 7 eliminazioni
+103 -1
Vedi File
@@ -874,6 +874,12 @@ bool IsCompatibleWithGenericTarget(const std::string& coTarget, const std::strin
std::vector<std::string> getLinkOptions(const LinkArguments& args) {
std::vector<std::string> res;
{ // process optimization level
std::string opt("-O");
opt += std::to_string(args.optimization_level_);
res.push_back(opt);
}
const auto irArgCount = args.linker_ir2isa_args_count_;
if (irArgCount > 0) {
res.reserve(irArgCount);
@@ -1010,12 +1016,108 @@ bool LinkProgram::AddLinkerOptions(unsigned int num_options, hipJitOption* optio
return false;
}
switch (options_ptr[opt_idx]) {
case hipJitOptionMaxRegisters:
link_args_.max_registers_ = *(reinterpret_cast<uint64_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionThreadsPerBlock:
link_args_.threads_per_block_ =
*(reinterpret_cast<uint64_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionWallTime:
link_args_.wall_time_ = *(reinterpret_cast<float*>(options_vals_ptr[opt_idx]));
break;
case hipJitOptionInfoLogBuffer: {
link_args_.info_log_ = (reinterpret_cast<char*>(options_vals_ptr[opt_idx]));
break;
}
case hipJitOptionInfoLogBufferSizeBytes:
link_args_.info_log_size_ = (reinterpret_cast<uint64_t>(options_vals_ptr[opt_idx]));
break;
case hipJitOptionErrorLogBuffer: {
link_args_.error_log_ = reinterpret_cast<char*>(options_vals_ptr[opt_idx]);
break;
}
case hipJitOptionErrorLogBufferSizeBytes:
link_args_.error_log_size_ = (reinterpret_cast<uint64_t>(options_vals_ptr[opt_idx]));
break;
case hipJitOptionOptimizationLevel:
link_args_.optimization_level_ =
*(reinterpret_cast<uint64_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionTargetFromContext:
link_args_.target_from_hip_context_ =
*(reinterpret_cast<uint64_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionTarget:
link_args_.jit_target_ = *(reinterpret_cast<uint64_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionFallbackStrategy:
link_args_.fallback_strategy_ =
*(reinterpret_cast<uint64_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionGenerateDebugInfo:
link_args_.generate_debug_info_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionLogVerbose:
link_args_.log_verbose_ = reinterpret_cast<uint64_t>(options_vals_ptr[opt_idx]);
break;
case hipJitOptionGenerateLineInfo:
link_args_.generate_line_info_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionCacheMode:
link_args_.cache_mode_ = *(reinterpret_cast<uint64_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionSm3xOpt:
link_args_.sm3x_opt_ = *(reinterpret_cast<bool*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionFastCompile:
link_args_.fast_compile_ = *(reinterpret_cast<bool*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionGlobalSymbolNames: {
link_args_.global_symbol_names_ = reinterpret_cast<const char**>(options_vals_ptr[opt_idx]);
break;
}
case hipJitOptionGlobalSymbolAddresses: {
link_args_.global_symbol_addresses_ = reinterpret_cast<void**>(options_vals_ptr[opt_idx]);
break;
}
case hipJitOptionGlobalSymbolCount:
link_args_.global_symbol_count_ =
*(reinterpret_cast<uint64_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionLto:
link_args_.lto_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionFtz:
link_args_.ftz_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionPrecDiv:
link_args_.prec_div_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionPrecSqrt:
link_args_.prec_sqrt_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionFma:
link_args_.fma_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionPositionIndependentCode:
link_args_.pic_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionMinCTAPerSM:
link_args_.min_cta_per_sm_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionMaxThreadsPerBlock:
link_args_.max_threads_per_block_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionOverrideDirectiveValues:
link_args_.override_directive_values_ = *(reinterpret_cast<uint32_t*>(&options_vals_ptr[opt_idx]));
break;
case hipJitOptionIRtoISAOptExt: {
link_args_.linker_ir2isa_args_ = reinterpret_cast<const char**>(options_vals_ptr[opt_idx]);
break;
}
case hipJitOptionIRtoISAOptCountExt:
link_args_.linker_ir2isa_args_count_ = reinterpret_cast<size_t>(options_vals_ptr[opt_idx]);
link_args_.linker_ir2isa_args_count_ = reinterpret_cast<uint64_t>(options_vals_ptr[opt_idx]);
break;
default:
break;
@@ -80,12 +80,49 @@ const std::map<std::string, std::string>& GenericTargetMapping();
// Both targets should not have any feature.
bool IsCompatibleWithGenericTarget(const std::string& coTarget, const std::string& agentTarget);
} // namespace helpers
/**
 * HIPRTC linker options.
 *
 * Plain value holder for the JIT/link options handed to the HIPRTC linker. Every member has an
 * in-class default initializer, so a default-constructed LinkArguments is fully initialized and
 * no user-provided constructor is required.
 *
 * Pointer members are stored exactly as assigned; nothing in this struct allocates or frees
 * the memory they refer to.
 */
struct LinkArguments {
  uint64_t max_registers_ = 0;      ///< Maximum registers that a thread may use
  uint64_t threads_per_block_ = 0;  ///< Minimum No. of threads per block
  float wall_time_ = 0.0f;          ///< Value for total wall clock time
  char* info_log_ = nullptr;        ///< Pointer to a buffer to print log information
  uint64_t info_log_size_ = 0;      ///< Size of the buffer in bytes for logged info
  char* error_log_ = nullptr;       ///< Pointer to a buffer to print log errors
  uint64_t error_log_size_ = 0;     ///< Size of the buffer in bytes for logged errors
  uint64_t optimization_level_ = 3; ///< Value of the optimization level for generated code
                                    ///< acceptable options -O0, -O1, -O2, -O3
  uint64_t target_from_hip_context_ = 0;  ///< Determines the target, based on the current context
  uint64_t jit_target_ = 0;               ///< CUDA Only JIT target
  uint64_t fallback_strategy_ = 0;        ///< CUDA Only Choice of fallback strategy
  uint32_t generate_debug_info_ = 0;      ///< Create debug information in output -g, if set
  uint64_t log_verbose_ = 0;              ///< Generate verbose log messages
  uint32_t generate_line_info_ = 0;       ///< Generate line number information
  uint64_t cache_mode_ = 0;               ///< CUDA Only Enables caching explicitly
  bool sm3x_opt_ = false;                 ///< CUDA Only New SM3X option
  bool fast_compile_ = false;             ///< CUDA Only Set fast compile
  const char** global_symbol_names_ = nullptr;  ///< Array of device symbol names to be relocated
                                                ///< to the host
  void** global_symbol_addresses_ = nullptr;    ///< Array of host addresses to be relocated to the
                                                ///< device
  uint64_t global_symbol_count_ = 0;            ///< Number of global symbols
  int32_t lto_ = 0;       ///< Enable link time optimization for device code
  int32_t ftz_ = 0;       ///< Set single-precision denormals
  int32_t prec_div_ = 1;  ///< Set single-precision floating-point division
                          ///< and reciprocals
  int32_t prec_sqrt_ = 1; ///< Set single-precision floating-point square root
  int32_t fma_ = 1;       ///< Enable floating-point multiplies and
                          ///< adds/subtracts operations
  int32_t pic_ = 0;       ///< Generates Position Independent code
  int32_t min_cta_per_sm_ = 0;  ///< Hints to JIT compiler the minimum number of
                                ///< CTAs from kernel's grid to be mapped to SM
  int32_t max_threads_per_block_ = 0;      ///< Maximum number of threads in a thread block
  int32_t override_directive_values_ = 0;  ///< Override Directive values
  const char** linker_ir2isa_args_ = nullptr;  ///< Hip Only Linker options to be passed on
                                               ///< to compiler
  uint64_t linker_ir2isa_args_count_ = 0;      ///< Hip Only Count of linker options to be passed
                                               ///< on to compiler
};
class RTCProgram {
@@ -119,7 +156,7 @@ class LinkProgram : public RTCProgram {
amd_comgr_data_kind_t data_kind_;
amd_comgr_data_kind_t GetCOMGRDataKind(hipJitInputType input_type);
// Linker Arguments at hipLinkCreate
LinkArguments link_args_;
// Spirv is bundled