diff --git a/projects/hip/api/hip/hip_platform.cpp b/projects/hip/api/hip/hip_platform.cpp index 5203c8123a..73055a1cb9 100644 --- a/projects/hip/api/hip/hip_platform.cpp +++ b/projects/hip/api/hip/hip_platform.cpp @@ -64,6 +64,16 @@ hipError_t hipModuleGetGlobal(hipDeviceptr_t* dptr, size_t* bytes, hipError_t ihipCreateGlobalVarObj(const char* name, hipModule_t hmod, amd::Memory** amd_mem_obj, hipDeviceptr_t* dptr, size_t* bytes); +static bool isCompatibleCodeObject(const std::string& codeobj_target_id, + const char* device_name) { + // Workaround for gfx906 device name mismatch. + // If bundle target id starts with gfx906 and device name starts with + // gfx906, treat them as match. + return codeobj_target_id.compare(device_name) == 0 || + (codeobj_target_id.find("gfx906") == 0 && + std::string(device_name).find("gfx906") == 0); +} + extern "C" std::vector< std::pair >* __hipRegisterFatBinary(const void* data) { HIP_INIT(); @@ -102,12 +112,7 @@ extern "C" std::vector< std::pair >* __hipRegisterFatBinary(c for (size_t dev = 0; dev < g_devices.size(); ++dev) { amd::Context* ctx = g_devices[dev]; - if (target.compare(ctx->devices()[0]->info().name_)) { - // Workaround for gfx906 device name mismatch. - // If bundle target id starts with gfx906 and device name starts with - // gfx906, treat them as match. - if (target.find("gfx906") != 0 || - std::string(ctx->devices()[0]->info().name_).find("gfx906") != 0) + if (!isCompatibleCodeObject(target, ctx->devices()[0]->info().name_)) { continue; } @@ -579,7 +584,7 @@ const std::vector& modules() { std::string target(desc->triple + sizeof(HCC_AMDGCN_AMDHSA_TRIPLE), desc->tripleSize - sizeof(HCC_AMDGCN_AMDHSA_TRIPLE)); - if (!target.compare(hip::getCurrentContext()->devices()[0]->info().name_)) { + if (isCompatibleCodeObject(target, hip::getCurrentContext()->devices()[0]->info().name_)) { hipModule_t module; if (hipSuccess == hipModuleLoadData(&module, reinterpret_cast( reinterpret_cast(obheader) + desc->offset)))