diff --git a/runtime/hsa-runtime/core/driver/kfd/amd_kfd_driver.cpp b/runtime/hsa-runtime/core/driver/kfd/amd_kfd_driver.cpp index 935007b6a3..723925f1bc 100644 --- a/runtime/hsa-runtime/core/driver/kfd/amd_kfd_driver.cpp +++ b/runtime/hsa-runtime/core/driver/kfd/amd_kfd_driver.cpp @@ -692,5 +692,29 @@ hsa_status_t KfdDriver::GetWallclockFrequency(uint32_t node_id, uint64_t* freque return HSA_STATUS_SUCCESS; } +hsa_status_t KfdDriver::ShareMemory(void* mem, size_t size, + HsaSharedMemoryHandle* share_mem) const { + assert(share_mem); + + if (HSAKMT_CALL(hsaKmtShareMemory(mem, size, share_mem)) != HSAKMT_STATUS_SUCCESS) { + return HSA_STATUS_ERROR; + } + + return HSA_STATUS_SUCCESS; +} + +hsa_status_t KfdDriver::RegisterSharedHandle(const HsaSharedMemoryHandle* share_mem, void** mem, + uint64_t* size) const { + assert(share_mem); + assert(mem); + assert(size); + + if (HSAKMT_CALL(hsaKmtRegisterSharedHandle(share_mem, mem, size)) != HSAKMT_STATUS_SUCCESS) { + return HSA_STATUS_ERROR; + } + + return HSA_STATUS_SUCCESS; +} + } // namespace AMD } // namespace rocr diff --git a/runtime/hsa-runtime/core/inc/amd_kfd_driver.h b/runtime/hsa-runtime/core/inc/amd_kfd_driver.h index ca8ee8a593..da1a6501e4 100644 --- a/runtime/hsa-runtime/core/inc/amd_kfd_driver.h +++ b/runtime/hsa-runtime/core/inc/amd_kfd_driver.h @@ -136,6 +136,9 @@ public: const HsaMemMapFlags* mem_flags, uint32_t num_nodes, const uint32_t* nodes) const override; hsa_status_t MakeMemoryUnresident(const void* mem) const override; + hsa_status_t ShareMemory(void* mem, size_t size, HsaSharedMemoryHandle* share_mem) const override; + hsa_status_t RegisterSharedHandle(const HsaSharedMemoryHandle* share_mem, void** mem, + uint64_t* size) const override; hsa_status_t OpenSMI(uint32_t node_id, int* fd) const override; diff --git a/runtime/hsa-runtime/core/inc/driver.h b/runtime/hsa-runtime/core/inc/driver.h index 3a4082e24f..f97e4b479a 100644 --- a/runtime/hsa-runtime/core/inc/driver.h +++ b/runtime/hsa-runtime/core/inc/driver.h @@ -351,6 +351,25 @@ public: /// @return HSA_STATUS_SUCCESS if the driver successfully makes the memory virtual hsa_status_t MakeMemoryUnresident(const void* mem) const = 0; + /// @brief Shares memory with another process. + /// @param[in] mem Pointer to the memory to be shared. + /// @param[in] size Size of the memory to be shared. + /// @param[out] share_mem Pointer to the shared memory handle. + /// @return HSA_STATUS_SUCCESS if the memory was successfully shared, or an error code. + virtual hsa_status_t ShareMemory(void* mem, size_t size, HsaSharedMemoryHandle* share_mem) const { + return HSA_STATUS_ERROR_INVALID_AGENT; + } + + /// @brief Registers a shared memory handle. + /// @param[in] share_mem Pointer to the shared memory handle. + /// @param[out] mem Pointer to the memory. + /// @param[out] size Size of the memory. + /// @return HSA_STATUS_SUCCESS if the memory was successfully registered, or an error code. + virtual hsa_status_t RegisterSharedHandle(const HsaSharedMemoryHandle* share_mem, void** mem, + uint64_t* size) const { + return HSA_STATUS_ERROR_INVALID_AGENT; + } + /// Unique identifier for supported kernel-mode drivers. const DriverType kernel_driver_type_; diff --git a/runtime/hsa-runtime/core/inc/runtime.h b/runtime/hsa-runtime/core/inc/runtime.h index 4b7e489d94..b9e45a8350 100644 --- a/runtime/hsa-runtime/core/inc/runtime.h +++ b/runtime/hsa-runtime/core/inc/runtime.h @@ -508,6 +508,22 @@ class Runtime { return **driver; } + /// @brief Check if the drivers of the agents are different. + /// @param [in] agents Array of agents to check. + /// @param [in] num_agents Number of agents in the array. + /// @return True if the drivers of the agents are different, false otherwise. + static bool IsDifferentDriver(Agent* agents, uint32_t num_agents) { + if (num_agents == 0 || agents == nullptr) return true; + + auto first_driver_type = agents[0].driver().kernel_driver_type_; + for (uint32_t i = 1; i < num_agents; ++i) { + if (agents[i].driver().kernel_driver_type_ != first_driver_type) { + return true; + } + } + return false; + } + std::vector>& AgentDrivers() { return agent_drivers_; } static bool IsGPUDriver(DriverType driver_type) { diff --git a/runtime/hsa-runtime/core/runtime/runtime.cpp b/runtime/hsa-runtime/core/runtime/runtime.cpp index 14b57254bb..b2c8be1f4f 100644 --- a/runtime/hsa-runtime/core/runtime/runtime.cpp +++ b/runtime/hsa-runtime/core/runtime/runtime.cpp @@ -1216,6 +1216,7 @@ hsa_status_t Runtime::IPCCreate(void* ptr, size_t len, hsa_amd_ipc_memory_t* han if (info.agentBaseAddress != ptr || info.sizeInBytes != len) return HSA_STATUS_ERROR_INVALID_ARGUMENT; + Agent* agent = Agent::Convert(info.agentOwner); bool useFrag = (block.base != ptr || block.length != len); // Assume all pointers and blocks are 4Kb aligned. uint32_t fragOffset = (reinterpret_cast(ptr) - @@ -1229,7 +1230,7 @@ hsa_status_t Runtime::IPCCreate(void* ptr, size_t len, hsa_amd_ipc_memory_t* han if (!ipc_dmabuf_supported_) { HsaSharedMemoryHandle *sHandle = reinterpret_cast(handle); - if (HSAKMT_CALL(hsaKmtShareMemory(block.base, block.length, sHandle)) != HSAKMT_STATUS_SUCCESS) + if (agent->driver().ShareMemory(block.base, block.length, sHandle) != HSA_STATUS_SUCCESS) return HSA_STATUS_ERROR_INVALID_ARGUMENT; hsa_status_t err = HSA_STATUS_SUCCESS; @@ -1250,7 +1251,6 @@ hsa_status_t Runtime::IPCCreate(void* ptr, size_t len, hsa_amd_ipc_memory_t* han handle->handle[1] = dmaBufFdHandleHi; handle->handle[2] = getpid(); // socket server name handle - Agent *agent = Agent::Convert(info.agentOwner); handle->handle[3] = agent->device_type() == Agent::kAmdCpuDevice; // System sub allocations are not supported for now. if (handle->handle[3] && useFrag) return HSA_STATUS_ERROR_INVALID_ARGUMENT; @@ -1391,6 +1391,9 @@ hsa_status_t Runtime::IPCAttach(const hsa_amd_ipc_memory_t* handle, size_t len, bool isFragment = false; uint32_t fragOffset = 0; + if (Runtime::IsDifferentDriver(*agents, num_agents)) return HSA_STATUS_ERROR_INVALID_ARGUMENT; + core::Driver* driver = &agents[0]->driver(); + auto fixFragment = [&](amdgpu_bo_handle ldrm_bo) { if (isFragment) { importAddress = reinterpret_cast(importAddress) + fragOffset; @@ -1402,14 +1405,17 @@ hsa_status_t Runtime::IPCAttach(const hsa_amd_ipc_memory_t* handle, size_t len, allocation_map_[importAddress].ldrm_bo = ldrm_bo; }; - auto importMemory = [&](unsigned int numNodes, HSAuint32 *nodes, - amdgpu_bo_import_result *res) { - int ret = ipc_dmabuf_supported_ ? - IPCClientImport(importHandle.handle[2], dmaBufFDHandle, res, - numNodes, nodes, &importAddress, &importSize) : - HSAKMT_CALL(hsaKmtRegisterSharedHandle(reinterpret_cast(&importHandle), - &importAddress, &importSize)); - if (ret != HSAKMT_STATUS_SUCCESS) return HSA_STATUS_ERROR_INVALID_ARGUMENT; + auto importMemory = [&](unsigned int numNodes, HSAuint32* nodes, amdgpu_bo_import_result* res) { + if (ipc_dmabuf_supported_) { + int ret = IPCClientImport(importHandle.handle[2], dmaBufFDHandle, res, numNodes, nodes, + &importAddress, &importSize); + if (ret != HSAKMT_STATUS_SUCCESS) return HSA_STATUS_ERROR_INVALID_ARGUMENT; + } else { + hsa_status_t ret = driver->RegisterSharedHandle( + reinterpret_cast(&importHandle), &importAddress, + &importSize); + if (ret != HSA_STATUS_SUCCESS) return ret; + } return HSA_STATUS_SUCCESS; };