/* Patch to src/ipc/context_ipc_tmpl_device.hpp: replaces the assert(false) stubs of IPCContext::amo_add, amo_set, amo_cas, amo_fetch_add and amo_fetch_cas with calls to ipcImpl_.ipcAMOAdd, ipcAMOSet, ipcAMOCas, ipcAMOFetchAdd and ipcAMOFetchCas, and renames the first parameter from dst to dest. NOTE(review): this text is collapsed onto two lines, and "template" / "reinterpret_cast" appear without angle-bracket arguments - presumably lost in extraction; confirm against the real header before applying. */ diff --git a/src/ipc/context_ipc_tmpl_device.hpp b/src/ipc/context_ipc_tmpl_device.hpp index 28cd718f0a..5e697cafaf 100644 --- a/src/ipc/context_ipc_tmpl_device.hpp +++ b/src/ipc/context_ipc_tmpl_device.hpp @@ -70,13 +70,21 @@ __device__ void IPCContext::get_nbi(T *dest, const T *source, size_t nelems, // Atomics template -__device__ void IPCContext::amo_add(void *dst, T value, int pe) { - assert(false); +__device__ void IPCContext::amo_add(void *dest, T value, int pe) { + /* pe is reduced modulo ipcImpl_.shm_size to obtain the index used into ipc_bases; presumably this maps a global PE id to its slot among the IPC-reachable PEs - TODO confirm. */ int local_pe = pe % ipcImpl_.shm_size; + /* L_offset is dest minus this PE's own base, ipc_bases[my_pe]; the same offset is applied to the target's base below. */ uint64_t L_offset = + reinterpret_cast(dest) - ipcImpl_.ipc_bases[my_pe]; + /* The translated address (ipc_bases[local_pe] + L_offset) and value are forwarded to ipcAMOAdd; nothing is returned. */ ipcImpl_.ipcAMOAdd( + reinterpret_cast(ipcImpl_.ipc_bases[local_pe] + L_offset), value); } template -__device__ void IPCContext::amo_set(void *dst, T value, int pe) { - assert(false); +__device__ void IPCContext::amo_set(void *dest, T value, int pe) { + /* Same translation as amo_add: offset of dest from ipc_bases[my_pe], rebased onto ipc_bases[local_pe], then handed to ipcAMOSet together with value. */ int local_pe = pe % ipcImpl_.shm_size; + uint64_t L_offset = + reinterpret_cast(dest) - ipcImpl_.ipc_bases[my_pe]; + ipcImpl_.ipcAMOSet( + reinterpret_cast(ipcImpl_.ipc_bases[local_pe] + L_offset), value); } template @@ -119,20 +127,32 @@ __device__ void IPCContext::amo_xor(void *dst, T value, int pe) { } template -__device__ void IPCContext::amo_cas(void *dst, T value, T cond, int pe) { - assert(false); +__device__ void IPCContext::amo_cas(void *dest, T value, T cond, int pe) { + int local_pe = pe % ipcImpl_.shm_size; + uint64_t L_offset = + reinterpret_cast(dest) - ipcImpl_.ipc_bases[my_pe]; + /* ipcAMOCas is passed cond before value, the reverse of this function's (value, cond) parameter order; its result, if any, is discarded. */ ipcImpl_.ipcAMOCas( + reinterpret_cast(ipcImpl_.ipc_bases[local_pe] + L_offset), cond, + value); } template -__device__ T IPCContext::amo_fetch_add(void *dst, T value, int pe) { - assert(false); - return 0; +__device__ T IPCContext::amo_fetch_add(void *dest, T value, int pe) { + int local_pe = pe % ipcImpl_.shm_size; + uint64_t L_offset = + reinterpret_cast(dest) - ipcImpl_.ipc_bases[my_pe]; + /* The result of ipcAMOFetchAdd on the translated address is returned to the caller (the removed stub always returned 0). */ return ipcImpl_.ipcAMOFetchAdd( + reinterpret_cast(ipcImpl_.ipc_bases[local_pe] + L_offset), value); } template -__device__ T 
IPCContext::amo_fetch_cas(void *dst, T value, T cond, int pe) { - assert(false); - return 0; +__device__ T IPCContext::amo_fetch_cas(void *dest, T value, T cond, int pe) { + int local_pe = pe % ipcImpl_.shm_size; + uint64_t L_offset = + reinterpret_cast(dest) - ipcImpl_.ipc_bases[my_pe]; + /* Returns whatever ipcAMOFetchCas yields for the translated address; cond is passed before value, the reverse of this function's parameter order. */ return ipcImpl_.ipcAMOFetchCas( + reinterpret_cast(ipcImpl_.ipc_bases[local_pe] + L_offset), cond, + value); } // Collectives