diff --git a/src/transport/net_ib.cc b/src/transport/net_ib.cc index 9bfd8dcf..4d3f0a08 100644 --- a/src/transport/net_ib.cc +++ b/src/transport/net_ib.cc @@ -29,6 +29,7 @@ #include "ibvwrap.h" #include "mlx5/mlx5dvwrap.h" +#include "ionic/ionicdvwrap.h" #include "graph/xml.h" #define MAXSUFFIXSIZE 16 @@ -110,16 +111,38 @@ struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_VDEVS]; struct ncclIbDev ncclIbDevs[MAX_IB_DEVS]; static std::mutex ncclIbMutex; static int ncclIbRelaxedOrderingEnabled = 0; +static bool rcclAinicRoce = 0; +static bool rcclCtsInlineData = 0; +static bool rcclCtsOffloadEnabled = 0; +static bool ncclIbUseInline = 0; +static int ncclIbGdrFlushDisable = 0; + +enum ncclIbChannelType { + ncclIbChannelTypeCts = 0, + ncclIbChannelTypeData = 1, + ncclIbChannelTypeMax = 2 +}; + +struct ncclChannelToUd { + int channelId; + bool udId; + bool udAllocated; +}; + +static ncclChannelToUd nccl_channel_ud_map[MAXCHANNELS][ncclIbChannelTypeMax]; +static bool nccl_channel_last_ud[MAX_IB_DEVS][ncclIbChannelTypeMax]; // With ncclNet_v11_t the NCCL core initializes the network plugin per-communicator // rather than once for all communicators. However, the internal plugin implementation // still assumes the plugin is initialized only once across all communicators. The ref // counter makes sure the plugin internally initializes only once. When per communicator // context support is added to the plugin the ref counter can be removed. static int netRefCount; #define NCCL_IB_LLSTR(ll) (((ll) == IBV_LINK_LAYER_INFINIBAND) ? "IB" : (((ll) == IBV_LINK_LAYER_ETHERNET) ? "RoCE" : "UNSPECIFIED")) +#define NCCL_CTS_QP_SLOT_INVALID 0xFF + #define NCCL_IB_SL_DEFAULT 0 #define NCCL_IB_TC_DEFAULT 0 @@ -141,6 +164,13 @@ NCCL_PARAM(IbEceEnable,"IB_ECE_ENABLE",1); NCCL_PARAM(IbDataDirect,"IB_DATA_DIRECT",1); NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1); RCCL_PARAM(IbQpsPerP2p, "IB_QPS_PER_P2P", 0); +NCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0); + +// AMD AINIC +RCCL_PARAM(CtsInlineData, "CTS_INLINE_DATA", -1); +RCCL_PARAM(CtsOffloadEnabled, "CTS_OFFLOAD_ENABLED", -1); + +extern int64_t rcclParamAinicRoce(); static ncclResult_t ncclIbStatsInit(struct ncclIbStats* stat) { __atomic_store_n(&stat->fatalErrorCount, 0, __ATOMIC_RELAXED); @@ -779,6 +809,10 @@ ncclResult_t ncclIbInit(void** ctx, uint64_t commId, ncclNetCommConfig_t* config static int shownIbHcaEnv = 0; if(wrap_ibv_symbols() != ncclSuccess) { return ncclInternalError; } if(wrap_mlx5dv_symbols() != ncclSuccess) { INFO(NCCL_NET, "NET/IB : Failed to open mlx5dv symbols. Advance features like CX-8 Direct-NIC will be disabled."); } + if(wrap_ionicdv_symbols() != ncclSuccess) { + WARN("NET/IB : Failed to open ionicdv symbols. Advance features like AINIC UD load balancing will be disabled."); + return ncclInternalError; + } // Detect IB cards int nIbDevs = 0; @@ -944,6 +978,23 @@ ncclResult_t ncclIbInit(void** ctx, uint64_t commId, ncclNetCommConfig_t* config INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "", ncclIbIfName, ncclSocketToString(&ncclIbIfAddr, addrline)); + ncclIbUseInline = ncclParamIbUseInline(); + ncclIbGdrFlushDisable = ncclParamIbGdrFlushDisable(); + + rcclAinicRoce = ((rcclParamAinicRoce() == 1) ? true : false); + if (rcclAinicRoce) { + // for AINIC, these params are defaulted to enabled unless user forces it to disable(0). + rcclCtsInlineData = ((rcclParamCtsInlineData() == 0) ? false : true); + rcclCtsOffloadEnabled = ((rcclParamCtsOffloadEnabled() == 0) ? false : true); + // for AINIC IbUseInline is enabled by default always + ncclIbUseInline = true; + // for AINIC GDR flush is disabled by default + ncclIbGdrFlushDisable = 1; + + INFO(NCCL_INIT|NCCL_NET, "NET/IB : AINIC RoCEv2 optimizations enabled: CTS Inline Data: %s; CTS Offload: %s; " + "IB Use Inline: enabled; GDR Flush: disabled", rcclCtsInlineData ? "Enabled": "Disabled", + rcclCtsOffloadEnabled ? "Enabled": "Disabled"); + } } exit: ibContext.trafficClass = config->trafficClass; @@ -1271,6 +1322,8 @@ struct ncclIbListenComm { struct ncclIbCommStage stage; }; +#define MAX_INLINE_DATA_SIZE 24 + struct alignas(64) ncclIbSendFifo { uint64_t addr; uint64_t size; @@ -1281,10 +1334,21 @@ struct alignas(64) ncclIbSendFifo { char padding[16]; }; +struct alignas(32) ncclIbSendFifoCtsInline { + uint64_t addr; + uint32_t rkeys[1]; + int size; + uint8_t nreqs; + uint16_t tag; + uint32_t idx; + char padding[9]; +} __attribute__((packed)); + struct ncclIbQp { struct ibv_qp* qp; int devIndex; int remDevIdx; + int8_t ctsQpSlot; }; struct ncclIbRemSizesFifo { @@ -1331,6 +1395,7 @@ struct ncclIbSendComm { struct ncclIbNetCommBase base; // Start with fifo and ibv structs as they have alignment restrictions struct ncclIbSendFifo fifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + struct ncclIbSendFifoCtsInline fifo_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; struct ibv_sge sges[NCCL_NET_IB_MAX_RECVS]; struct ibv_send_wr wrs[NCCL_NET_IB_MAX_RECVS + 1]; // Each dev correlates to a mergedIbDev @@ -1346,6 +1411,7 @@ struct ncclIbSendComm { static_assert((sizeof(struct ncclIbNetCommBase) % 32) == 0, "ncclIbNetCommBase size must be 32-byte multiple to ensure fifo is at proper offset"); static_assert((offsetof(struct ncclIbSendComm, fifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); static_assert((sizeof(struct ncclIbSendFifo) % 32) == 0, "ncclIbSendFifo element size must be 32-byte multiples"); +static_assert((sizeof(struct ncclIbSendFifoCtsInline) % 32) == 0, "ncclIbSendFifoCtsInline element size must be 32-byte multiples"); static_assert((offsetof(struct ncclIbSendComm, sges) % 32) == 0, "sges must be 32-byte aligned"); static_assert((offsetof(struct ncclIbSendComm, wrs) % 32) == 0, "wrs must be 32-byte aligned"); @@ -1360,6 +1426,7 @@ struct ncclIbGpuFlush { struct ncclIbRemFifo { struct ncclIbSendFifo elems[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; + struct ncclIbSendFifoCtsInline elems_cts_inline[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; uint64_t fifoTail; uint64_t addr; uint32_t flags; @@ -1415,20 +1482,59 @@ ncclResult_t ncclIbDestroyBase(struct ncclIbNetCommDevBase* base) { return ncclSuccess; } -ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, int access_flags, void* qp_context, struct ncclIbQp* qp) { +ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, + int access_flags, void* qp_context, struct ncclIbQp* qp, + int channel_id, bool data_qp, int8_t cts_qp_slot) { struct ibv_qp_init_attr qpInitAttr; + enum ncclIbChannelType channel_type = (data_qp ? ncclIbChannelTypeData : ncclIbChannelTypeCts); memset(&qpInitAttr, 0, sizeof(struct ibv_qp_init_attr)); qpInitAttr.qp_context = qp_context; qpInitAttr.send_cq = base->cq; qpInitAttr.recv_cq = base->cq; qpInitAttr.qp_type = IBV_QPT_RC; + + if (rcclAinicRoce) { + if (!nccl_channel_ud_map[channel_id][channel_type].udAllocated) { + bool lud = nccl_channel_last_ud[base->ibDevN][channel_type]; + nccl_channel_ud_map[channel_id][channel_type].udId = lud; + nccl_channel_ud_map[channel_id][channel_type].udAllocated = true; + nccl_channel_last_ud[base->ibDevN][channel_type] = + !(nccl_channel_last_ud[base->ibDevN][channel_type]); + } + if (nccl_channel_ud_map[channel_id][channel_type].udId) { + wrap_ionicdv_pd_set_udma_mask(base->pd, IONIC_UDMA_MASK_HIGH); + } else { + wrap_ionicdv_pd_set_udma_mask(base->pd, IONIC_UDMA_MASK_LOW); + } + qpInitAttr.sq_sig_all |= (1 << 16); + if (data_qp) { + qpInitAttr.sq_sig_all |= (1 << 17); + } else { + qpInitAttr.sq_sig_all &= (~(1 << 17)); + } + qpInitAttr.sq_sig_all |= (1 << 18); + + if (rcclCtsOffloadEnabled) { + qpInitAttr.sq_sig_all |= (1 << 19); + } else { + qpInitAttr.sq_sig_all &= (~(1 << 19)); + } + } + // We might send 2 messages per send (RDMA and RDMA_WITH_IMM) qpInitAttr.cap.max_send_wr = 2*MAX_REQUESTS; qpInitAttr.cap.max_recv_wr = MAX_REQUESTS; qpInitAttr.cap.max_send_sge = 1; qpInitAttr.cap.max_recv_sge = 1; - qpInitAttr.cap.max_inline_data = ncclParamIbUseInline() ? sizeof(struct ncclIbSendFifo) : 0; + if (rcclCtsInlineData) { + qpInitAttr.cap.max_inline_data = MAX_INLINE_DATA_SIZE; + } else { + qpInitAttr.cap.max_inline_data = ncclIbUseInline ? sizeof(struct ncclIbSendFifo) : 0; + } NCCLCHECK(wrap_ibv_create_qp(&qp->qp, base->pd, &qpInitAttr)); + if (rcclAinicRoce) { + NCCLCHECK(wrap_ionicdv_qp_set_gda(qp->qp, false, true)); + } struct ibv_qp_attr qpAttr; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); qpAttr.qp_state = IBV_QPS_INIT; @@ -1438,6 +1544,9 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); TRACE(NCCL_NET, "NET/IB : ncclIbCreateQp port=%d dev=%d devName=%s ndevs=%d nmdevs=%d qpn=%u pkey=%u pd=%p", ib_port, base->ibDevN, ncclIbDevs[base->ibDevN].devName, ncclNIbDevs, ncclNMergedIbDevs, qp->qp->qp_num, qpAttr.pkey_index, base->pd); + if (rcclAinicRoce) { + qp->ctsQpSlot = cts_qp_slot; + } return ncclSuccess; } @@ -1521,7 +1630,7 @@ fail: goto exit; } -ncclResult_t ncclIbConnect(void* ctx, int dev, void* opaqueHandle, void** sendComm, ncclNetDeviceHandle_t** /*sendDevComm*/) { +ncclResult_t ncclIbConnect(void* ctx, int dev, void* opaqueHandle, void** sendComm, ncclNetDeviceHandle_t** sendDevComm) { ncclResult_t ret = ncclSuccess; struct ncclIbHandle* handle = (struct ncclIbHandle*) opaqueHandle; struct ncclIbCommStage* stage = &handle->stage; @@ -1529,8 +1638,13 @@ ncclResult_t ncclIbConnect(void* ctx, int dev, void* opaqueHandle, void** sendCo int ready; uint8_t link_layer = IBV_LINK_LAYER_UNSPECIFIED; int isP2p = 0; + int channel_id = 0; *sendComm = NULL; + if (rcclAinicRoce) { + channel_id = ((ncclNet_ctxt_t *)sendDevComm)->chId; + } + if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; @@ -1612,7 +1726,7 @@ ib_recv_dev_list: for (int q = 0; q < comm->base.nqps; q++) { ncclIbSendCommDev* commDev = comm->devs + devIndex; ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; - NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, &comm->base.stats, comm->base.qps + q), ret, fail); + NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &commDev->base, IBV_ACCESS_REMOTE_WRITE, &comm->base.stats, comm->base.qps + q, channel_id, true, NCCL_CTS_QP_SLOT_INVALID), ret, fail); comm->base.qps[q].devIndex = devIndex; meta.qpInfo[q].qpn = comm->base.qps[q].qp->qp_num; meta.qpInfo[q].devIndex = comm->base.qps[q].devIndex; @@ -1637,7 +1751,11 @@ ib_recv_dev_list: devInfo->lid = ibDev->portAttr.lid; devInfo->ibv_dev_index = commDev->base.ibDevN; // Prepare my fifo - NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + if (rcclCtsInlineData) { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo_inline, sizeof(struct ncclIbSendFifoCtsInline)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } else { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } devInfo->fifoRkey = commDev->fifoMr->rkey; // Pack local GID info @@ -1680,7 +1798,11 @@ ib_recv_dev_list: } } config = (ncclNetCommConfig_t*)ctx; - meta.fifoAddr = (uint64_t)comm->fifo; + if (rcclCtsInlineData) { + meta.fifoAddr = (uint64_t)comm->fifo_inline; + } else { + meta.fifoAddr = (uint64_t)comm->fifo; + } meta.sl = (ncclParamIbSl() != -1) ? ncclParamIbSl() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_SL_DEFAULT; meta.tc = (ncclParamIbTc() != -1) ? ncclParamIbTc() : (config && config->trafficClass != NCCL_NET_TRAFFIC_CLASS_UNDEF) ? config->trafficClass : NCCL_IB_TC_DEFAULT; strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); @@ -1825,18 +1947,22 @@ ncclResult_t ncclIbCheckVProps(ncclNetVDeviceProps_t* vProps1, ncclNetVDevicePro return ncclSuccess; } -NCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0); RCCL_PARAM(IbGdrFlushGpuMemNoRelaxedOrdering, "GDR_FLUSH_GPU_MEM_NO_RELAXED_ORDERING", 1); -ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** /*recvDevComm*/) { +ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** recvDevComm) { ncclResult_t ret = ncclSuccess; struct ncclIbListenComm* lComm = (struct ncclIbListenComm*)listenComm; struct ncclIbCommStage* stage = &lComm->stage; struct ncclIbRecvComm* rComm = (struct ncclIbRecvComm*)stage->comm; int ready; int link_layer = IBV_LINK_LAYER_UNSPECIFIED; + int channel_id = 0; *recvComm = NULL; + if (rcclAinicRoce) { + channel_id = ((ncclNet_ctxt_t *) recvDevComm)->chId; + } + if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; @@ -1966,7 +2092,7 @@ ib_recv: // Local ibDevN ibDevN = rComm->devs[devIndex].base.ibDevN; ibDev = ncclIbDevs + ibDevN; - NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp), ret, fail); + NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp, channel_id, false, q), ret, fail); qp->devIndex = devIndex; devIndex = (devIndex + 1) % rComm->base.vProps.ndevs; @@ -1992,16 +2118,22 @@ ib_recv: useDmaBuf = (ncclIbDmaBufSupport(lComm->dev) == ncclSuccess); rComm->flushEnabled = ((ncclIbGdrSupport() == ncclSuccess || useDmaBuf) - && (ncclParamIbGdrFlushDisable() == 0)) ? 1 : 0; + && (ncclIbGdrFlushDisable == 0)) ? 1 : 0; for (int i = 0; i < rComm->base.vProps.ndevs; i++) { rCommDev = rComm->devs + i; ibDev = ncclIbDevs + rCommDev->base.ibDevN; // Retain remote fifo info and prepare my RDMA ops rComm->remFifo.addr = remMeta.fifoAddr; - NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + if (rcclCtsInlineData) { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems_cts_inline, + sizeof(struct ncclIbSendFifoCtsInline)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, + IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } else { + NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); + } rCommDev->fifoSge.lkey = rCommDev->fifoMr->lkey; - if (ncclParamIbUseInline()) rComm->remFifo.flags = IBV_SEND_INLINE; + if (ncclIbUseInline) rComm->remFifo.flags = IBV_SEND_INLINE; // Allocate Flush dummy buffer for GPU Direct RDMA if (rComm->flushEnabled) { @@ -2039,7 +2171,7 @@ ib_recv: rCommDev->gpuFlush.sge.addr = (uint64_t)&rComm->gpuFlushHostMem; rCommDev->gpuFlush.sge.length = 1; rCommDev->gpuFlush.sge.lkey = rCommDev->gpuFlush.hostMr->lkey; - NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, &rCommDev->gpuFlush.qp), ret, fail); + NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, &rCommDev->gpuFlush.qp, channel_id, true, NCCL_CTS_QP_SLOT_INVALID), ret, fail); struct ncclIbDevInfo devInfo; devInfo.lid = ibDev->portAttr.lid; devInfo.link_layer = ibDev->portAttr.link_layer; @@ -2257,10 +2389,15 @@ ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { NCCL_PARAM(IbSplitDataOnQps, "IB_SPLIT_DATA_ON_QPS", 0); -ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { +ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot, bool use_write_op) { struct ncclIbRequest** reqs = comm->fifoReqs[slot]; volatile struct ncclIbSendFifo* slots = comm->fifo[slot]; - int nreqs = slots[0].nreqs; + int nreqs; + if (rcclCtsOffloadEnabled) { + nreqs = 1; + } else { + nreqs = slots[0].nreqs; + } if (nreqs > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; uint64_t wr_id = 0ULL; @@ -2272,7 +2409,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { sge->addr=(uintptr_t)reqs[r]->send.data; wr->opcode = IBV_WR_RDMA_WRITE; wr->send_flags = 0; - wr->wr.rdma.remote_addr = slots[r].addr; + if (rcclCtsOffloadEnabled) { + wr->wr.rdma.remote_addr = 0xdeadbeef; + } else { + wr->wr.rdma.remote_addr = slots[r].addr; + } wr->next = wr + 1; wr_id += (reqs[r] - comm->base.reqs) << (r*8); #ifdef NCCL_ENABLE_NET_PROFILING @@ -2283,7 +2424,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { // Write size as immediate data. In the case of multi-send, only write // 0 or 1 as size to indicate whether there was data sent or received. uint32_t immData = 0; - if (nreqs == 1) { + if ((nreqs == 1) && (use_write_op == false)) { immData = reqs[0]->send.size; } else { int* sizes = comm->remSizesFifo.elems[slot]; @@ -2293,22 +2434,24 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { } struct ibv_send_wr* lastWr = comm->wrs+nreqs-1; - if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamIbArThreshold())) { - // When using ADAPTIVE_ROUTING, send the bulk of the data first as an - // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote - // completion. - lastWr++; - memset(lastWr, 0, sizeof(struct ibv_send_wr)); - if (nreqs > 1) { - // Write remote sizes Fifo - lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int); - lastWr->num_sge = 1; - lastWr->sg_list = &comm->remSizesFifo.sge; + if (use_write_op == false) { + if (nreqs > 1 || (comm->ar && reqs[0]->send.size > ncclParamIbArThreshold())) { + // When using ADAPTIVE_ROUTING, send the bulk of the data first as an + // RDMA_WRITE, then a 0-byte RDMA_WRITE_WITH_IMM to trigger a remote + // completion. + lastWr++; + memset(lastWr, 0, sizeof(struct ibv_send_wr)); + if (nreqs > 1) { + // Write remote sizes Fifo + lastWr->wr.rdma.remote_addr = comm->remSizesFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(int); + lastWr->num_sge = 1; + lastWr->sg_list = &comm->remSizesFifo.sge; + } } + lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + lastWr->imm_data = immData; } lastWr->wr_id = wr_id; - lastWr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; - lastWr->imm_data = immData; lastWr->next = NULL; lastWr->send_flags = IBV_SEND_SIGNALED; @@ -2324,7 +2467,11 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { //ncclIbAddEvent(reqs[r], devIndex, &comm->devs[devIndex].base); // Select proper rkey (needed even for 0-size send) - comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx]; + if (rcclCtsOffloadEnabled) { + comm->wrs[r].wr.rdma.rkey = 0xbade; + } else { + comm->wrs[r].wr.rdma.rkey = slots[r].rkeys[qp->remDevIdx]; + } int chunkSize = DIVUP(DIVUP(reqs[r]->send.size, nqps), align) * align; int length = std::min(reqs[r]->send.size-reqs[r]->send.offset, chunkSize); @@ -2340,7 +2487,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { } } - if (nreqs > 1) { + if ((use_write_op == false) && (nreqs > 1)) { // Also make sure lastWr writes remote sizes using the right lkey comm->remSizesFifo.sge.lkey = comm->remSizesFifo.mrs[devIndex]->lkey; lastWr->wr.rdma.rkey = comm->remSizesFifo.rkeys[devIndex]; @@ -2398,32 +2545,46 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__)); struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; + bool use_write_op = false; + if (rcclAinicRoce) { + use_write_op = (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) ? true : false; + } // Wait for the receiver to have posted the corresponding receive int nreqs = 0; volatile struct ncclIbSendFifo* slots; + if (rcclCtsOffloadEnabled) { + nreqs = 1; + } + int slot = (comm->fifoHead) % MAX_REQUESTS; struct ncclIbRequest** reqs = comm->fifoReqs[slot]; - slots = comm->fifo[slot]; - uint64_t idx = comm->fifoHead+1; - if (slots[0].idx != idx) { *request = NULL; return ncclSuccess; } - nreqs = slots[0].nreqs; - // Wait until all data has arrived - for (int r=1; rfifo[slot]; + uint64_t idx = comm->fifoHead+1; + if (slots[0].idx != idx) { *request = NULL; return ncclSuccess; } + nreqs = slots[0].nreqs; + // Wait until all data has arrived + for (int r=1; r slots[r].size) size = slots[r].size; - // Sanity checks - if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) { - char line[SOCKET_NAME_MAXLEN + 1]; - union ncclSocketAddress addr; - ncclSocketGetAddr(&comm->base.sock, &addr); - WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x", - r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]); - return ncclInternalError; + if (!rcclCtsOffloadEnabled) { + if (reqs[r] != NULL || slots[r].tag != tag) continue; + + if (size > slots[r].size) size = slots[r].size; + // Sanity checks + if (slots[r].size < 0 || slots[r].addr == 0 || slots[r].rkeys[0] == 0) { + char line[SOCKET_NAME_MAXLEN + 1]; + union ncclSocketAddress addr; + ncclSocketGetAddr(&comm->base.sock, &addr); + WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x", + r, nreqs, tag, ncclSocketToString(&addr, line), slots[r].size, slots[r].addr, slots[r].rkeys[0]); + return ncclInternalError; + } + } else{ + if (reqs[r] != NULL) continue; } struct ncclIbRequest* req; @@ -2467,10 +2628,12 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* } TIME_START(0); - NCCLCHECK(ncclIbMultiSend(comm, slot)); + NCCLCHECK(ncclIbMultiSend(comm, slot, use_write_op)); // Clear slots[0]->nreqs, as well as other fields to help debugging and sanity checks - memset((void*)slots, 0, sizeof(struct ncclIbSendFifo)); + if (!rcclCtsOffloadEnabled) { + memset((void*)slots, 0, sizeof(struct ncclIbSendFifo)); + } memset(reqs, 0, NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbRequest*)); comm->fifoHead++; TIME_STOP(0); @@ -2483,30 +2646,60 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, size_t* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) { struct ibv_send_wr wr; + struct ncclIbSendFifo* localElem = NULL; + struct ncclIbSendFifoCtsInline* localElemCtsInline = NULL; + uint64_t localElemRef; + int qpIndex = 0; + ncclIbQp* ctsQp = NULL; memset(&wr, 0, sizeof(wr)); int slot = comm->remFifo.fifoTail%MAX_REQUESTS; req->recv.sizes = comm->sizesFifo[slot]; for (int i=0; irecv.sizes[i] = 0; - struct ncclIbSendFifo* localElem = comm->remFifo.elems[slot]; + if (rcclCtsInlineData) { + localElemCtsInline = comm->remFifo.elems_cts_inline[slot]; + } else { + localElem = comm->remFifo.elems[slot]; + } - // Select the next devIndex (local) and QP to use for posting this CTS message - // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value - ncclIbQp* ctsQp = comm->base.qps + comm->base.devIndex; - comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs; + if (rcclAinicRoce) { + qpIndex = comm->base.qpIndex; + ctsQp = comm->base.qps + qpIndex; + } else { + // Select the next devIndex (local) and QP to use for posting this CTS message + // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value + ctsQp = comm->base.qps + comm->base.devIndex; + comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs; + } for (int i=0; ibase.vProps.ndevs; j++) + localElemCtsInline[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; + + localElemCtsInline[i].nreqs = n; + localElemCtsInline[i].size = sizes[i]; // Sanity/Debugging + localElemCtsInline[i].tag = tags[i]; + localElemCtsInline[i].idx = comm->remFifo.fifoTail+1; + localElemRef = (uint64_t)localElemCtsInline; + + } else { + localElem[i].addr = (uint64_t)data[i]; - // Send all applicable rkeys - for (int j = 0; j < comm->base.vProps.ndevs; j++) - localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; + // Send all applicable rkeys + for (int j = 0; j < comm->base.vProps.ndevs; j++) + localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; - localElem[i].nreqs = n; - localElem[i].size = sizes[i]; // Sanity/Debugging - localElem[i].tag = tags[i]; - localElem[i].idx = comm->remFifo.fifoTail+1; + localElem[i].nreqs = n; + localElem[i].size = sizes[i]; // Sanity/Debugging + localElem[i].tag = tags[i]; + localElem[i].idx = comm->remFifo.fifoTail+1; + localElemRef = (uint64_t)localElem; + } } wr.wr.rdma.remote_addr = comm->remFifo.addr + slot*NCCL_NET_IB_MAX_RECVS*sizeof(struct ncclIbSendFifo); @@ -2514,8 +2707,12 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz wr.wr.rdma.rkey = comm->base.remDevs[ctsQp->remDevIdx].fifoRkey; // Set the correct sge properties - comm->devs[ctsQp->devIndex].fifoSge.addr = (uint64_t)localElem; - comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo); + comm->devs[ctsQp->devIndex].fifoSge.addr = localElemRef; + if (rcclCtsInlineData) { + comm->devs[ctsQp->devIndex].fifoSge.length = MAX_INLINE_DATA_SIZE; + } else { + comm->devs[ctsQp->devIndex].fifoSge.length = n*sizeof(struct ncclIbSendFifo); + } wr.sg_list = &comm->devs[ctsQp->devIndex].fifoSge; wr.num_sge = 1; @@ -2545,7 +2742,13 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz // // slot == devIndex - When writing to fifo slot N, and this QP lives on device index N, it should send signalled. // This works out that each fifo posting QP gets drained - if (slot == ctsQp->devIndex) { + if (rcclAinicRoce) { + if (slot == ctsQp->ctsQpSlot) { + wr.send_flags |= IBV_SEND_SIGNALED; + wr.wr_id = req - comm->base.reqs; + ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base); + } + } else if (slot == ctsQp->devIndex) { wr.send_flags |= IBV_SEND_SIGNALED; wr.wr_id = req - comm->base.reqs; ncclIbAddEvent(req, ctsQp->devIndex, &comm->devs[ctsQp->devIndex].base); @@ -2560,10 +2763,16 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, siz comm->remFifo.fifoTail++; + if (rcclAinicRoce) { + // Select the next qpIndex + comm->base.qpIndex = (comm->base.qpIndex+1) % comm->base.nqps; + } return ncclSuccess; } ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** phandles, void** request) { + ncclResult_t res = ncclSuccess; + bool netOptRecvCompletionEnabled = false; struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; if (comm->base.ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->base.ready == 0"); @@ -2573,6 +2782,11 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* if (n > NCCL_NET_IB_MAX_RECVS) return ncclInternalError; NCCLCHECK(ncclIbStatsCheckFatalCount(&comm->base.stats,__func__)); + if (rcclAinicRoce) { + if (*request == (void *) NCCL_NET_OPTIONAL_RECV_COMPLETION) { + netOptRecvCompletionEnabled = true; + } + } struct ncclIbRequest* req; NCCLCHECK(ncclIbGetRequest(&comm->base, &req)); req->type = NCCL_NET_IB_REQ_RECV; @@ -2586,50 +2800,64 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* req->devBases[i] = &comm->devs[i].base; } - struct ibv_recv_wr wr; - memset(&wr, 0, sizeof(wr)); - wr.wr_id = req - comm->base.reqs; - wr.sg_list = NULL; - wr.num_sge = 0; + if (!netOptRecvCompletionEnabled) { + struct ibv_recv_wr wr; + memset(&wr, 0, sizeof(wr)); + wr.wr_id = req - comm->base.reqs; + wr.sg_list = NULL; + wr.num_sge = 0; - TIME_START(1); - // Select either all QPs, or one qp per-device - const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; + TIME_START(1); + // Select either all QPs, or one qp per-device + const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; - // Post recvs - struct ibv_recv_wr* bad_wr; - for (int i = 0; i < nqps; i++) { - struct ncclIbQp* qp = comm->base.qps + comm->base.qpIndex; - ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base); + // Post recvs + struct ibv_recv_wr* bad_wr; + int qpIndex = comm->base.qpIndex; + for (int i = 0; i < nqps; i++) { + struct ncclIbQp* qp = comm->base.qps + comm->base.qpIndex; + ncclIbAddEvent(req, qp->devIndex, &comm->devs[qp->devIndex].base); #ifdef NCCL_ENABLE_NET_PROFILING - // Start a QP event for every request in the multirecv and every qp - for (int r = 0; r < n; r++) { - int nEventHandles = req->pInfo[r].nEventHandles; - assert(nEventHandles < MAX_QPS_PER_REQ); - req->pInfo[r].qpIndex[nEventHandles] = comm->base.qpIndex; - // Store info for profiler - int64_t pluginId = NCCL_PROFILER_NET_TYPE_IB | NCCL_PROFILER_NET_IB_VER; - req->pInfo[r].data.type = ncclProfileQp; - req->pInfo[r].data.qp.device = qp->devIndex; - req->pInfo[r].data.qp.wr_id = wr.wr_id; - req->pInfo[r].data.qp.qpNum = qp->qp->qp_num; - NCCLCHECK(ncclProfilerFunction(&req->pInfo[r].qpEventHandles[nEventHandles], ncclProfilerNetEventStart, phandles[r], pluginId, &req->pInfo[r].data)); - req->pInfo[r].nEventHandles++; - } + // Start a QP event for every request in the multirecv and every qp + for (int r = 0; r < n; r++) { + int nEventHandles = req->pInfo[r].nEventHandles; + assert(nEventHandles < MAX_QPS_PER_REQ); + req->pInfo[r].qpIndex[nEventHandles] = comm->base.qpIndex; + // Store info for profiler + int64_t pluginId = NCCL_PROFILER_NET_TYPE_IB | NCCL_PROFILER_NET_IB_VER; + req->pInfo[r].data.type = ncclProfileQp; + req->pInfo[r].data.qp.device = qp->devIndex; + req->pInfo[r].data.qp.wr_id = wr.wr_id; + req->pInfo[r].data.qp.qpNum = qp->qp->qp_num; + NCCLCHECK(ncclProfilerFunction(&req->pInfo[r].qpEventHandles[nEventHandles], ncclProfilerNetEventStart, phandles[r], pluginId, &req->pInfo[r].data)); + req->pInfo[r].nEventHandles++; + } #endif - NCCLCHECK(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr)); - comm->base.qpIndex = (comm->base.qpIndex+1)%comm->base.nqps; - } + NCCLCHECKGOTO(wrap_ibv_post_recv(qp->qp, &wr, &bad_wr), res, err); + // Don't update comm->base.qpIndex yet, we need to run through this same set of QPs + // inside ncclIbPostFifo() + if (rcclAinicRoce) { + qpIndex = (qpIndex+1)%comm->base.nqps; + } else { + comm->base.qpIndex = (comm->base.qpIndex+1)%comm->base.nqps; + } + } - TIME_STOP(1); + TIME_STOP(1); + } // netOptRecvCompletionEnabled = false // Post to FIFO to notify sender TIME_START(2); - NCCLCHECK(ncclIbPostFifo(comm, n, data, sizes, tags, mhandles, req)); + NCCLCHECKGOTO(ncclIbPostFifo(comm, n, data, sizes, tags, mhandles, req), res, err); TIME_STOP(2); *request = req; return ncclSuccess; +err: + if (req) { + ncclIbFreeRequest(req); + } + return res; } ncclResult_t ncclIbIflush(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request) { @@ -2698,6 +2926,8 @@ static int getReqQpIndex(struct ncclIbRequest* req, int request, int qpNumber) { } #endif +#define NCCL_CQ_POLL_MAX_EVENT 16 + ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { struct ncclIbRequest *r = (struct ncclIbRequest*)request; *done = 0; @@ -2731,13 +2961,18 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { int totalWrDone = 0; int wrDone = 0; - struct ibv_wc wcs[4]; + struct ibv_wc wcs[NCCL_CQ_POLL_MAX_EVENT]; + int cqMaxPollEvent = 4; + if (rcclAinicRoce) { + cqMaxPollEvent = NCCL_CQ_POLL_MAX_EVENT; + } for (int i = 0; i < NCCL_IB_MAX_DEVS_PER_NIC; i++) { TIME_START(3); // If we expect any completions from this device's CQ if (r->events[i]) { - NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, 4, wcs, &wrDone)); + NCCLCHECK(wrap_ibv_poll_cq(r->devBases[i]->cq, cqMaxPollEvent, + wcs, &wrDone)); totalWrDone += wrDone; if (wrDone == 0) { TIME_CANCEL(3); } else { TIME_STOP(3); } if (wrDone == 0) continue; @@ -2889,7 +3124,7 @@ ncclResult_t rcclNetP2pPolicy(void* handle, int isP2p) { } ncclNet_t ncclNetIb = { - "IB", + "ROCM-IB", ncclIbInit, ncclIbDevices, ncclIbGetProperties,