diff --git a/hipamd/api/hip/hip_peer.cpp b/hipamd/api/hip/hip_peer.cpp index 14a41f9953..e986cfeac0 100644 --- a/hipamd/api/hip/hip_peer.cpp +++ b/hipamd/api/hip/hip_peer.cpp @@ -53,7 +53,31 @@ hipError_t hipMemcpyPeerAsync(void* dst, hipCtx_t dstDevice, const void* src, hi hipError_t hipDeviceCanAccessPeer(int* canAccessPeer, int deviceId, int peerDeviceId) { HIP_INIT_API(canAccessPeer, deviceId, peerDeviceId); - *canAccessPeer = 0; + amd::Device* device = nullptr; + amd::Device* peer_device = nullptr; + + if (canAccessPeer == nullptr) { + HIP_RETURN(hipErrorInvalidValue); + } + + /* Peer cannot be self */ + if (deviceId == peerDeviceId) { + *canAccessPeer = 0; + return HIP_RETURN(hipSuccess); + } + + /* Cannot exceed the max number of devices */ + if (static_cast(deviceId) >= g_devices.size() + || static_cast(peerDeviceId) >= g_devices.size()) { + return HIP_RETURN(hipErrorInvalidValue); + } + + device = g_devices[deviceId]->devices()[0]; + peer_device = g_devices[peerDeviceId]->devices()[0]; + + *canAccessPeer = static_cast(std::find(device->p2pDevices_.begin(), + device->p2pDevices_.end(), as_cl(peer_device)) + != device->p2pDevices_.end()); return HIP_RETURN(hipSuccess); } @@ -61,17 +85,13 @@ hipError_t hipDeviceCanAccessPeer(int* canAccessPeer, int deviceId, int peerDevi hipError_t hipDeviceDisablePeerAccess(int peerDeviceId) { HIP_INIT_API(peerDeviceId); - assert(0 && "Unimplemented"); - - HIP_RETURN(hipErrorUnknown); + HIP_RETURN(hipSuccess); } hipError_t hipDeviceEnablePeerAccess(int peerDeviceId, unsigned int flags) { HIP_INIT_API(peerDeviceId, flags); - assert(0 && "Unimplemented"); - - HIP_RETURN(hipErrorUnknown); + HIP_RETURN(hipSuccess); } hipError_t hipMemcpyPeer(void* dst, int dstDevice, const void* src, int srcDevice, @@ -95,15 +115,11 @@ hipError_t hipMemcpyPeerAsync(void* dst, int dstDevice, const void* src, int src hipError_t hipCtxEnablePeerAccess(hipCtx_t peerCtx, unsigned int flags) { HIP_INIT_API(peerCtx, flags); - assert(0 && "Unimplemented"); - - HIP_RETURN(hipErrorUnknown); + HIP_RETURN(hipSuccess); } hipError_t hipCtxDisablePeerAccess(hipCtx_t peerCtx) { HIP_INIT_API(peerCtx); - assert(0 && "Unimplemented"); - - HIP_RETURN(hipErrorUnknown); + HIP_RETURN(hipSuccess); }