diff --git a/hipamd/api/hip/hip_module.cpp b/hipamd/api/hip/hip_module.cpp index 3837a35daf..2163ef3da4 100644 --- a/hipamd/api/hip/hip_module.cpp +++ b/hipamd/api/hip/hip_module.cpp @@ -93,8 +93,16 @@ hipError_t hipModuleLoadData(hipModule_t *module, const void *image) HIP_RETURN(ihipModuleLoadData(module, image)); } +extern bool __hipExtractCodeObjectFromFatBinary(const void* data, + const std::vector& devices, + std::vector>& code_objs); + hipError_t ihipModuleLoadData(hipModule_t *module, const void *image) { + std::vector> code_objs; + if (__hipExtractCodeObjectFromFatBinary(image, {hip::getCurrentContext()->devices()[0]->info().name_}, code_objs)) + image = code_objs[0].first; /* on success, continue with the code object extracted for the first device of the current context; otherwise image is used exactly as passed in */ + amd::Program* program = new amd::Program(*hip::getCurrentContext()); if (program == NULL) { return hipErrorOutOfMemory; diff --git a/hipamd/api/hip/hip_platform.cpp b/hipamd/api/hip/hip_platform.cpp index 73055a1cb9..f09bc1dea9 100644 --- a/hipamd/api/hip/hip_platform.cpp +++ b/hipamd/api/hip/hip_platform.cpp @@ -74,26 +74,23 @@ static bool isCompatibleCodeObject(const std::string& codeobj_target_id, std::string(device_name).find("gfx906") == 0); } -extern "C" std::vector< std::pair >* __hipRegisterFatBinary(const void* data) +// Extracts code objects from the offload-bundle fat binary at data for the device names given in devices; code_objs is resized to devices.size() and code_objs[i] receives the (image, size) pair of a bundle compatible with devices[i]. +// Returns false if data does not start with CLANG_OFFLOAD_BUNDLER_MAGIC_STR; otherwise returns true only when the number of compatible matches found equals devices.size(). 
+bool __hipExtractCodeObjectFromFatBinary(const void* data, + const std::vector& devices, + std::vector>& code_objs) { HIP_INIT(); - if(g_devices.empty()) { - return nullptr; - } - const __CudaFatBinaryWrapper* fbwrapper = reinterpret_cast(data); - if (fbwrapper->magic != __hipFatMAGIC2 || fbwrapper->version != 1) { - return nullptr; - } - std::string magic((char*)fbwrapper->binary, sizeof(CLANG_OFFLOAD_BUNDLER_MAGIC_STR) - 1); + std::string magic((const char*)data, sizeof(CLANG_OFFLOAD_BUNDLER_MAGIC_STR) - 1); if (magic.compare(CLANG_OFFLOAD_BUNDLER_MAGIC_STR)) { - return nullptr; + return false; } - auto programs = new std::vector< std::pair >{g_devices.size()}; - - const auto obheader = reinterpret_cast(fbwrapper->binary); + code_objs.resize(devices.size()); /* one slot per entry of devices, addressed by the same index */ + const auto obheader = reinterpret_cast(data); const auto* desc = &obheader->desc[0]; + unsigned num_code_objs = 0; /* running count of compatible (bundle, device) pairs */ for (uint64_t i = 0; i < obheader->numBundles; ++i, desc = reinterpret_cast( reinterpret_cast(&desc->triple[0]) + desc->tripleSize)) { @@ -109,20 +106,54 @@ extern "C" std::vector< std::pair >* __hipRegisterFatBinary(c reinterpret_cast(obheader) + desc->offset); size_t size = desc->size; - for (size_t dev = 0; dev < g_devices.size(); ++dev) { - amd::Context* ctx = g_devices[dev]; + for (size_t dev = 0; dev < devices.size(); ++dev) { + const char* name = devices[dev]; - if (!isCompatibleCodeObject(target, ctx->devices()[0]->info().name_)) { + if (!isCompatibleCodeObject(target, name)) { continue; } + code_objs[dev] = std::make_pair(image, size); /* a later compatible bundle overwrites an earlier one stored for the same device */ + num_code_objs++; /* NOTE(review): incremented on every match, so a device matched by two bundles is counted twice - confirm the equality check below is the intended success test */ + } + } + if (num_code_objs == devices.size()) + return true; + else + return false; +} - amd::Program* program = new amd::Program(*ctx); - if (program == nullptr) { - return nullptr; - } - if (CL_SUCCESS == program->addDeviceProgram(*ctx->devices()[0], image, size)) { - programs->at(dev) = std::make_pair(reinterpret_cast(as_cl(program)) , false); - } +extern "C" std::vector< std::pair >* __hipRegisterFatBinary(const void* data) +{ + 
HIP_INIT(); + + if(g_devices.empty()) { + return nullptr; + } + const __CudaFatBinaryWrapper* fbwrapper = reinterpret_cast(data); + if (fbwrapper->magic != __hipFatMAGIC2 || fbwrapper->version != 1) { + return nullptr; + } + + std::vector devices; + std::vector> code_objs; + for (size_t dev = 0; dev < g_devices.size(); ++dev) { + amd::Context* ctx = g_devices[dev]; + devices.push_back(ctx->devices()[0]->info().name_); /* name of the first device of each context in g_devices, in g_devices order */ + } + + if (!__hipExtractCodeObjectFromFatBinary((char*)fbwrapper->binary, devices, code_objs)) { + return nullptr; + } + + auto programs = new std::vector< std::pair >{g_devices.size()}; + for (size_t dev = 0; dev < g_devices.size(); ++dev) { + amd::Context* ctx = g_devices[dev]; + amd::Program* program = new amd::Program(*ctx); + if (program == nullptr) { + return nullptr; /* NOTE(review): programs, allocated with new just above, is not released on this path */ + } + if (CL_SUCCESS == program->addDeviceProgram(*ctx->devices()[0], code_objs[dev].first, code_objs[dev].second)) { + programs->at(dev) = std::make_pair(reinterpret_cast(as_cl(program)) , false); } }